Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ option(WITH_CPU "Enable CPU backend" OFF)
option(WITH_NVIDIA "Enable CUDA backend" OFF)
option(WITH_ILUVATAR "Enable Iluvatar GPU backend" OFF)
option(WITH_METAX "Enable MetaX backend" OFF)
option(WITH_CAMBRICON "Enable Cambricon backend" OFF)

option(AUTO_DETECT_DEVICES "Automatically detect available devices" OFF)
option(GENERATE_PYTHON_BINDINGS "Generate Python bindings" OFF)
Expand All @@ -32,6 +33,10 @@ if(AUTO_DETECT_DEVICES)
message(STATUS "Auto-detected Iluvatar environment.")
endif()

# Auto-detect a Cambricon toolchain via the standard NEUWARE_HOME environment
# variable (mirrors the Iluvatar auto-detection above).
if(DEFINED ENV{NEUWARE_HOME})
set(WITH_CAMBRICON ON)
message(STATUS "Auto-detected Cambricon environment.")
endif()
# TODO: Please test and uncomment/update the auto-detection for MetaX.
# if(DEFINED ENV{MACA_PATH})
# set(WITH_METAX ON)
Expand Down Expand Up @@ -72,7 +77,22 @@ if(WITH_METAX)
find_library(MACA_BLAS_LIB NAMES mcblas HINTS "${MACA_PATH}/lib" REQUIRED)
endif()

# Cambricon backend: locate the Neuware SDK pointed at by NEUWARE_HOME and
# resolve the runtime/kernel libraries it ships.
# (The stale "CPU is enabled by default" comment that used to sit here belongs
# to the CPU-fallback check below, which carries its own comment.)
if(WITH_CAMBRICON)
add_compile_definitions(WITH_CAMBRICON=1)
# Quoted so an install path containing spaces stays a single value.
set(NEUWARE_HOME "$ENV{NEUWARE_HOME}")

include_directories("${NEUWARE_HOME}/include")
link_directories("${NEUWARE_HOME}/lib")
link_directories("${NEUWARE_HOME}/lib64")

# Libraries: cnrt / cnnl / cnnl_extra / cnpapi.
# Hint both lib64 and lib: Neuware installs have shipped either layout.
find_library(CAMBRICON_RUNTIME_LIB NAMES cnrt HINTS "${NEUWARE_HOME}/lib64" "${NEUWARE_HOME}/lib" REQUIRED)
find_library(CAMBRICON_CNNL_LIB NAMES cnnl HINTS "${NEUWARE_HOME}/lib64" "${NEUWARE_HOME}/lib" REQUIRED)
find_library(CAMBRICON_CNNL_EXTRA_LIB NAMES cnnl_extra HINTS "${NEUWARE_HOME}/lib64" "${NEUWARE_HOME}/lib" REQUIRED)
find_library(CAMBRICON_PAPI_LIB NAMES cnpapi HINTS "${NEUWARE_HOME}/lib64" "${NEUWARE_HOME}/lib" REQUIRED)
endif()

# If all other platforms are not enabled, CPU is enabled by default.
# NOTE: WITH_CAMBRICON must be part of this check — otherwise a
# Cambricon-only build would also compile the CPU backend.
if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX AND NOT WITH_CAMBRICON)
add_compile_definitions(WITH_CPU=1)
endif()
Expand Down
3 changes: 3 additions & 0 deletions examples/gemm/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
#if WITH_METAX
#include "metax/gemm/mcblas.h"
#endif
#if WITH_CAMBRICON
#include "cambricon/gemm/cnblas.h"
#endif

#include "runtime_api.h"
#include "tensor.h"
Expand Down
9 changes: 9 additions & 0 deletions examples/runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@
#define DEVICE_MEMCPY_HOST_TO_DEVICE mcMemcpyHostToDevice
#define DEVICE_MEMCPY_DEVICE_TO_HOST mcMemcpyDeviceToHost
#define DEFAULT_DEVICE_TYPE Device::Type::kMetax
#elif WITH_CAMBRICON
#include <cnrt.h>
#define DEVICE_MALLOC cnrtMalloc
#define DEVICE_FREE cnrtFree
#define DEVICE_MEMCPY cnrtMemcpy
#define DEVICE_MEMSET cnrtMemset
#define DEVICE_MEMCPY_HOST_TO_DEVICE cnrtMemcpyHostToDev
#define DEVICE_MEMCPY_DEVICE_TO_HOST cnrtMemcpyDevToHost
#define DEFAULT_DEVICE_TYPE Device::Type::kCambricon
#elif WITH_CPU
#include <cstdlib>
#include <cstring>
Expand Down
10 changes: 10 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ if(WITH_METAX)
list(APPEND DEVICE_LIST "metax")
endif()


# Wire the Cambricon backend into the infiniops target.
# PUBLIC: the define, the Neuware headers and the CNNL/CNRT libraries are all
# part of infiniops' usage requirements, so consumers inherit them.
# NEUWARE_HOME and the CAMBRICON_*_LIB variables are resolved by the
# top-level CMakeLists.txt (find_library ... REQUIRED).
if(WITH_CAMBRICON)
target_compile_definitions(infiniops PUBLIC WITH_CAMBRICON=1)

target_include_directories(infiniops PUBLIC "${NEUWARE_HOME}/include")
target_link_libraries(infiniops PUBLIC ${CAMBRICON_RUNTIME_LIB} ${CAMBRICON_CNNL_LIB} ${CAMBRICON_CNNL_EXTRA_LIB} ${CAMBRICON_PAPI_LIB})

list(APPEND DEVICE_LIST "cambricon")
endif()

target_include_directories(infiniops PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})

if(GENERATE_PYTHON_BINDINGS)
Expand Down
25 changes: 25 additions & 0 deletions src/cambricon/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef INFINI_OPS_CAMBRICON_COMMON_H_
#define INFINI_OPS_CAMBRICON_COMMON_H_

#include <cnnl.h>

#include "data_type.h"

namespace infini::ops::cnnl_utils {

// Translates the framework's DataType enum into the matching CNNL dtype.
// Types without a CNNL mapping yield CNNL_DTYPE_INVALID so callers can
// detect and reject them.
inline cnnlDataType_t GetDtype(DataType dtype) {
  if (dtype == DataType::kFloat16) {
    return CNNL_DTYPE_HALF;
  }
  if (dtype == DataType::kFloat32) {
    return CNNL_DTYPE_FLOAT;
  }
  if (dtype == DataType::kInt32) {
    return CNNL_DTYPE_INT32;
  }
  return CNNL_DTYPE_INVALID;
}

} // namespace infini::ops::cnnl_utils

#endif
153 changes: 153 additions & 0 deletions src/cambricon/gemm/cnblas.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
#ifndef INFINI_OPS_CAMBRICON_GEMM_CNBLAS_H_
#define INFINI_OPS_CAMBRICON_GEMM_CNBLAS_H_

#include <cassert>
#include <memory>
#include <vector>

#include <cnnl.h>
#include <cnrt.h>

#include "base/gemm.h"
#include "cambricon/common.h"

namespace infini::ops {

template <>
// GEMM operator backed by Cambricon CNNL's batched matmul (BatchMatMulEx).
// Descriptors and the algorithm/workspace size are prepared once in the
// constructor; operator() only sets the queue, allocates the workspace and
// launches the kernel.
class Operator<Gemm, Device::Type::kCambricon> : public Gemm {
 public:
  // Full constructor.
  //   a, b, c          : operand tensors; the innermost two dims are the
  //                      matrix dims, any leading dims form the batch.
  //   alpha, beta      : optional GEMM scalars (defaults come from Gemm).
  //   trans_a, trans_b : must be unset/0 — transposition is not supported yet.
  Operator(const Tensor a, const Tensor b, std::optional<float> alpha,
           std::optional<float> beta, std::optional<int> trans_a,
           std::optional<int> trans_b, Tensor c)
      : Gemm{a, b, alpha, beta, trans_a, trans_b, c},
        a_rows_{a.size(-2)}, a_cols_{a.size(-1)},
        b_rows_{b.size(-2)}, b_cols_{b.size(-1)},
        c_rows_{c.size(-2)}, c_cols_{c.size(-1)} {
    // Currently only support non-transposed matrices.
    assert(!trans_a_ && "trans_a=true is not currently supported");
    assert(!trans_b_ && "trans_b=true is not currently supported");

    cnnlCreate(&cnnl_handle_);

    cnnlCreateTensorDescriptor(&desc_a_);
    cnnlCreateTensorDescriptor(&desc_b_);
    cnnlCreateTensorDescriptor(&desc_c_);

    cnnlCreateMatMulDescriptor(&matmul_desc_);
    cnnlCreateMatMulAlgo(&matmul_algo_);
    cnnlCreateMatMulHeuristicResult(&heuristic_result_);

    // Tell CNNL to honor the explicit strides we place on the tensor
    // descriptors (needed for non-contiguous operands).
    int32_t use_stride = 1;
    cnnlSetMatMulDescAttr(matmul_desc_, CNNL_MATMUL_USE_STRIDE, &use_stride,
                          sizeof(int32_t));

    // Describe each operand by its physical (rows, cols) plus strides.
    SetupTensorDescriptor(desc_a_, a_strides_, a_type_, a_rows_, a_cols_,
                          batch_count_, batch_stride_a_);
    SetupTensorDescriptor(desc_b_, b_strides_, b_type_, b_rows_, b_cols_,
                          batch_count_, batch_stride_b_);
    SetupTensorDescriptor(desc_c_, c_strides_, c_type_, c_rows_, c_cols_,
                          batch_count_, batch_stride_c_);

    // Select an algorithm once and cache its workspace size, so operator()
    // does not have to re-run the heuristic on every call.
    int count = 0;
    cnnlGetBatchMatMulExAlgoHeuristic(
        cnnl_handle_,
        matmul_desc_, desc_a_, desc_b_, desc_c_,
        NULL, 1, &heuristic_result_, &count);
    cnnlGetBatchMatMulExHeuristicResult(heuristic_result_, matmul_algo_,
                                        &workspace_size_);
  }

  // Convenience constructor: default alpha/beta, no transposition.
  Operator(const Tensor a, const Tensor b, Tensor c)
      : Operator{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
                 c} {}

  // Convenience constructor: explicit alpha/beta, no transposition.
  Operator(const Tensor a, const Tensor b, std::optional<float> alpha,
           std::optional<float> beta, Tensor c)
      : Operator{a, b, alpha, beta, std::nullopt, std::nullopt, c} {}

  // Release CNNL resources; the handle is destroyed last.
  ~Operator() {
    cnnlDestroyTensorDescriptor(desc_a_);
    cnnlDestroyTensorDescriptor(desc_b_);
    cnnlDestroyTensorDescriptor(desc_c_);
    cnnlDestroyMatMulDescriptor(matmul_desc_);
    cnnlDestroyMatMulAlgo(matmul_algo_);
    cnnlDestroyMatMulHeuristicResult(heuristic_result_);
    cnnlDestroy(cnnl_handle_);
  }

  // Executes c = alpha * a @ b + beta * c on the given queue.
  // alpha/beta fall back to the values captured at construction time.
  void operator()(void* stream, const Tensor a, const Tensor b,
                  std::optional<float> alpha, std::optional<float> beta,
                  std::optional<int> trans_a, std::optional<int> trans_b,
                  Tensor c) const override {
    const auto& alpha_value{alpha.value_or(alpha_)};
    const auto& beta_value{beta.value_or(beta_)};

    // Route this execution onto the caller-supplied queue.
    cnnlSetQueue(cnnl_handle_, (cnrtQueue_t)stream);

    // Allocate the workspace using the size pre-computed by the heuristic.
    void* workspace = nullptr;
    if (workspace_size_ > 0) {
      cnrtMalloc(&workspace, workspace_size_);
    }

    // Enqueue the (asynchronous) batched matmul.
    cnnlBatchMatMulEx(
        cnnl_handle_, matmul_desc_, matmul_algo_, &alpha_value, desc_a_,
        a.data(), desc_b_,
        b.data(), &beta_value, desc_c_,
        c.data(), workspace, workspace_size_);

    // BUGFIX: the kernel runs asynchronously on the queue, so we must wait
    // for it to complete BEFORE freeing the workspace — freeing first was a
    // device-memory use-after-free race.
    cnrtQueueSync((cnrtQueue_t)stream);
    if (workspace) {
      cnrtFree(workspace);
    }
  }

 private:
  // Fills `desc` with the operand's layout: [rows, cols] for a single
  // matrix, or [batch, rows, cols] when batch > 1, using the provided
  // element strides (batch_stride leads when batched).
  void SetupTensorDescriptor(cnnlTensorDescriptor_t desc,
                             const Tensor::Strides& strides, DataType dtype,
                             Tensor::Size rows, Tensor::Size cols,
                             Tensor::Size batch, Tensor::Stride batch_stride) {
    cnnlDataType_t cnnl_dtype = cnnl_utils::GetDtype(dtype);

    std::vector<int> dims;
    std::vector<int> strides_arr;
    if (batch > 1) {
      // Leading batch dimension with its own stride.
      dims.push_back(static_cast<int>(batch));
      strides_arr.push_back(static_cast<int>(batch_stride));
    }
    dims.push_back(static_cast<int>(rows));
    dims.push_back(static_cast<int>(cols));
    strides_arr.push_back(static_cast<int>(strides[strides.size() - 2]));
    strides_arr.push_back(static_cast<int>(strides[strides.size() - 1]));

    cnnlSetTensorDescriptorEx(desc, CNNL_LAYOUT_ARRAY, cnnl_dtype,
                              dims.size(), dims.data(), strides_arr.data());
  }

  cnnlHandle_t cnnl_handle_;
  cnnlTensorDescriptor_t desc_a_;
  cnnlTensorDescriptor_t desc_b_;
  cnnlTensorDescriptor_t desc_c_;
  cnnlMatMulDescriptor_t matmul_desc_;
  cnnlMatMulAlgo_t matmul_algo_;
  cnnlMatMulHeuristicResult_t heuristic_result_;

  // Physical storage dimensions for each tensor
  Tensor::Size a_rows_, a_cols_;
  Tensor::Size b_rows_, b_cols_;
  Tensor::Size c_rows_, c_cols_;
};

} // namespace infini::ops

#endif
3 changes: 3 additions & 0 deletions src/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,9 @@ class Operator : public OperatorBase {
auto operator()(const Tensor tensor, Args&&... args) const {
return operator()(stream_, tensor, std::forward<Args>(args)...);
}

protected:
size_t workspace_size_{0};
};

} // namespace infini::ops
Expand Down
4 changes: 4 additions & 0 deletions tests/test_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def test_gemm(
rtol,
atol,
):
# Skip trans test for MLU platform as it is not currently supported
if device == "mlu" and (trans_a or trans_b):
pytest.skip("目前MLU平台上的GEMM算子不支持trans相关设置")

a = randn_strided(a_shape, a_strides, dtype=dtype, device=device)
b = randn_strided(b_shape, b_strides, dtype=dtype, device=device)

Expand Down