Skip to content
Merged
36 changes: 29 additions & 7 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ project(InfiniOps LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Internal variable to control pybind11's automatic optimization flags (like `-flto`).
set(PYBIND11_ENABLE_EXTRAS ON)

# Options for backends.
option(WITH_CPU "Enable CPU backend" OFF)
option(WITH_NVIDIA "Enable CUDA backend" OFF)
Expand Down Expand Up @@ -32,11 +35,26 @@ if(AUTO_DETECT_DEVICES)
message(STATUS "Auto-detected Iluvatar environment.")
endif()

# TODO: Please test and uncomment/update the auto-detection for MetaX.
# if(DEFINED ENV{MACA_PATH})
# set(WITH_METAX ON)
# message(STATUS "Auto-detected MetaX environment.")
# endif()
if(DEFINED ENV{MACA_PATH})
set(WITH_METAX ON)
message(STATUS "Auto-detected MetaX environment from MACA_PATH")
else()
execute_process(
COMMAND sh -c "grep -h 9999 /sys/bus/pci/devices/*/vendor 2>/dev/null"
OUTPUT_VARIABLE _pci_vendor_output
OUTPUT_STRIP_TRAILING_WHITESPACE
)

string(FIND "${_pci_vendor_output}" "9999" _found_pos)

if(_found_pos GREATER -1)
set(WITH_METAX ON)
message(STATUS "Detected MetaX GPU from PCI vendor ID 0x9999")
else()
set(WITH_METAX OFF)
message(STATUS "No MetaX GPU detected")
endif()
endif()
endif()

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)
Expand Down Expand Up @@ -75,8 +93,8 @@ if(WITH_METAX)

# Normally can be found at: `/opt/maca/`.
set(MACA_PATH $ENV{MACA_PATH})
set(CMAKE_C_COMPILER ${MACA_PATH}/mxgpu_llvm/bin/mxcc)
set(CMAKE_CXX_COMPILER ${MACA_PATH}/mxgpu_llvm/bin/mxcc)
set(CMAKE_C_COMPILER ${CMAKE_CURRENT_SOURCE_DIR}/scripts/mxcc_wrapper.sh)
set(CMAKE_CXX_COMPILER ${CMAKE_CURRENT_SOURCE_DIR}/scripts/mxcc_wrapper.sh)

include_directories("${MACA_PATH}/include")
link_directories("${MACA_PATH}/lib")
Expand All @@ -92,6 +110,10 @@ if(NOT WITH_NVIDIA AND NOT WITH_ILUVATAR AND NOT WITH_METAX)
add_compile_definitions(WITH_CPU=1)
endif()

if(WITH_METAX)
set(PYBIND11_ENABLE_EXTRAS OFF)
endif()

add_subdirectory(src)

add_subdirectory(examples)
24 changes: 24 additions & 0 deletions scripts/mxcc_wrapper.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash
# Filter out flags unsupported by `mxcc`.
ARGS=()
skip_next=0
for arg in "$@"; do
if [ $skip_next -eq 1 ]; then
skip_next=0
continue
fi
case "$arg" in
-pthread)
;;
-B)
skip_next=1
;;
-B*)
;;
*)
ARGS+=("$arg")
;;
esac
done

exec ${MACA_PATH}/mxgpu_llvm/bin/mxcc "${ARGS[@]}"
6 changes: 5 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ if(GENERATE_PYTHON_BINDINGS)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG)

pybind11_add_module(ops ${PYBIND11_SOURCES})
if(PYBIND11_ENABLE_EXTRAS)
pybind11_add_module(ops ${PYBIND11_SOURCES})
else()
pybind11_add_module(ops NO_EXTRAS ${PYBIND11_SOURCES})
endif()

target_include_directories(ops PRIVATE ${PROJECT_SOURCE_DIR})
target_link_libraries(ops PRIVATE infiniops)
Expand Down
4 changes: 2 additions & 2 deletions src/base/add.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ class Add : public Operator<Add> {
is_other_contiguous_{other.IsContiguous()},
is_out_contiguous_{out.IsContiguous()} {
assert(!out.HasBroadcastDim() &&
"The output of `Add` should NOT have broadcasted dim!");
"the output of `Add` should NOT have broadcasted dim!");
// TODO(lzm): support mix-precision later using the generic elementwise
// framework.
assert(input_type_ == other_type_ && other_type_ == out_type_ &&
"Operator `Add` requires all input and output Tensors to have the "
"operator `Add` requires all input and output Tensors to have the "
"same dtype");
}

Expand Down
2 changes: 1 addition & 1 deletion src/common/constexpr_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ struct ConstexprMap {
if (pr.first == key) return pr.second;
}
// TODO(lzm): change to logging.
assert("ConstexprMap's key is not found!");
assert("the key is not found in the `ConstexprMap`");
// Unreachable, provided to satisfy the compiler's requirement.
std::abort();
}
Expand Down
148 changes: 105 additions & 43 deletions src/common/traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,101 +6,163 @@

namespace infini::ops {

// --------------------- List and TypePack ---------------------
// A generic container for a sequence of compile-time values.
template <auto... Items>
template <auto... items>
struct List {};

// `ListGet<index>(List<items...>{})` extracts the `i`th value from a `List`
// tag.
template <std::size_t index, auto head, auto... tail>
constexpr auto ListGetImpl(List<head, tail...>) {
if constexpr (index == 0)
return head;
else
return ListGetImpl<index - 1>(List<tail...>{});
}

template <std::size_t index, auto... items>
constexpr auto ListGet(List<items...> list) {
return ListGetImpl<index>(list);
}

template <typename... Ts>
struct TypePack {};

// -----------------------------------------------------------------------------
// Tags
// -----------------------------------------------------------------------------
// Tags are passed as regular function arguments to user functors instead of
// template parameters. This lets users write plain C++17 `[](auto tag)` lambdas
// rather than C++20 template lambdas (`[]<typename T>()`).

// `TypeTag<T>`: carries a C++ type. Recover with `typename
// decltype(tag)::type`.
template <typename T>
struct TypeTag {
using type = T;
};

// `ValueTag<V>`: carries a compile-time value. Recover with
// `decltype(tag)::value`.
template <auto v>
struct ValueTag {
using value_type = decltype(v);
static constexpr auto value = v;
};

// -----------------------------------------------------------------------------
// List Queries
// -----------------------------------------------------------------------------

// Check at compile-time if a Value exists within a construct (e.g., List<>).
// Example: static_assert(ContainsValue<SupportedTiles, 32>);
template <typename T, auto Value>
// Check at compile-time if a value exists within a construct (e.g., `List<>`).
// Example: `static_assert(ContainsValue<SupportedTiles, 32>)`;
template <typename T, auto value>
struct Contains;

template <auto Value, auto... Items>
struct Contains<List<Items...>, Value>
: std::disjunction<std::bool_constant<Value == Items>...> {};
template <auto value, auto... items>
struct Contains<List<items...>, value>
: std::disjunction<std::bool_constant<value == items>...> {};

template <typename T, auto Value>
inline constexpr bool ContainsValue = Contains<T, Value>::value;
template <typename T, auto value>
inline constexpr bool ContainsValue = Contains<T, value>::value;

// Check at compile-time if a type T is present in a variadic list of types Ts.
// Example: static_assert(IsTypeInList<T, float, int>);
// Check at compile-time if a type `T` is present in a variadic list of types
// `Ts`.
// Example: `static_assert(IsTypeInList<T, float, int>)`;
template <typename T, typename... Ts>
inline constexpr bool IsTypeInList = (std::is_same_v<T, Ts> || ...);

// Trait to detect whether `T` is a `List<...>` specialization.
template <typename T>
struct IsListType : std::false_type {};

template <auto... items>
struct IsListType<List<items...>> : std::true_type {};

// -----------------------------------------------------------------------------
// List Operations
// -----------------------------------------------------------------------------

// Concatenates two List types into a single List.
// Example: ConcatType<List<1, 2>, List<3, 4>> is List<1, 2, 3, 4>.
// Concatenates two List types into a single `List`.
// Example: `ConcatType<List<1, 2>, List<3, 4>>` is `List<1, 2, 3, 4>`.
template <typename L1, typename L2>
struct Concat;

template <auto... I1, auto... I2>
struct Concat<List<I1...>, List<I2...>> {
using type = List<I1..., I2...>;
template <auto... item1, auto... item2>
struct Concat<List<item1...>, List<item2...>> {
using type = List<item1..., item2...>;
};

template <typename L1, typename L2>
using ConcatType = typename Concat<L1, L2>::type;

template <typename... Lists>
struct Flatten;

template <auto... items>
struct Flatten<List<items...>> {
using type = List<items...>;
};

template <typename L1, typename L2, typename... Rest>
struct Flatten<L1, L2, Rest...> {
using type = typename Flatten<ConcatType<L1, L2>, Rest...>::type;
};

// -----------------------------------------------------------------------------
// Invocability Detection (SFINAE)
// -----------------------------------------------------------------------------

// Checks if a Functor's template operator()<Value> can be called with Args.
template <typename Functor, auto Value, typename = void, typename... Args>
// Checks if a `Functor` can be called with a `ValueTag<Value>` and `Args...`.
template <typename Functor, auto value, typename = void, typename... Args>
struct IsInvocable : std::false_type {};

template <typename Functor, auto Value, typename... Args>
struct IsInvocable<
Functor, Value,
std::void_t<decltype(std::declval<Functor>().template operator()<Value>(
std::declval<Args>()...))>,
Args...> : std::true_type {};
template <typename Functor, auto value, typename... Args>
struct IsInvocable<Functor, value,
std::void_t<decltype(std::declval<Functor>()(
ValueTag<value>{}, std::declval<Args>()...))>,
Args...> : std::true_type {};

template <typename Functor, auto Value, typename... Args>
template <typename Functor, auto value, typename... Args>
inline constexpr bool IsInvocableValue =
IsInvocable<Functor, Value, void, Args...>::value;
IsInvocable<Functor, value, void, Args...>::value;

// -----------------------------------------------------------------------------
// Filtering Logic
// -----------------------------------------------------------------------------

// Recursive template to filter values based on Functor support at compile-time.
// Recursive template to filter values based on `Functor` support at
// compile-time.
template <typename Functor, typename ArgsTuple, typename Result,
auto... Remaining>
auto... remaining>
struct Filter;

// Base case: All values processed.
template <typename Functor, typename... Args, auto... Filtered>
struct Filter<Functor, std::tuple<Args...>, List<Filtered...>> {
using type = List<Filtered...>;
template <typename Functor, typename... Args, auto... filtered>
struct Filter<Functor, std::tuple<Args...>, List<filtered...>> {
using type = List<filtered...>;
};

// Recursive step: Test the 'Head' value and accumulate if supported.
template <typename Functor, typename... Args, auto... Filtered, auto Head,
auto... Tail>
struct Filter<Functor, std::tuple<Args...>, List<Filtered...>, Head, Tail...> {
// Recursive step: Test the `head` value and accumulate if supported.
template <typename Functor, typename... Args, auto... filtered, auto head,
auto... tail>
struct Filter<Functor, std::tuple<Args...>, List<filtered...>, head, tail...> {
using type = typename std::conditional_t<
IsInvocableValue<Functor, Head, Args...> &&
!ContainsValue<List<Filtered...>, Head>,
Filter<Functor, std::tuple<Args...>, List<Filtered..., Head>, Tail...>,
Filter<Functor, std::tuple<Args...>, List<Filtered...>, Tail...>>::type;
IsInvocableValue<Functor, head, Args...> &&
!ContainsValue<List<filtered...>, head>,
Filter<Functor, std::tuple<Args...>, List<filtered..., head>, tail...>,
Filter<Functor, std::tuple<Args...>, List<filtered...>, tail...>>::type;
};

// Interface to filter a List type directly.
// Interface to filter a `List` type directly.
template <typename Functor, typename ArgsTuple, typename ListType>
struct FilterList;

template <typename Functor, typename... Args, auto... Items>
struct FilterList<Functor, std::tuple<Args...>, List<Items...>> {
template <typename Functor, typename... Args, auto... items>
struct FilterList<Functor, std::tuple<Args...>, List<items...>> {
using type =
typename Filter<Functor, std::tuple<Args...>, List<>, Items...>::type;
typename Filter<Functor, std::tuple<Args...>, List<>, items...>::type;
};

} // namespace infini::ops
Expand Down
10 changes: 7 additions & 3 deletions src/cpu/add/add.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@ class Operator<Add, Device::Type::kCpu> : public Add {
void operator()(void* stream, const Tensor input, const Tensor other,
Tensor out) const override {
DispatchFunc<ConcatType<FloatTypes, AllIntTypes>>(
out_type_, [&]<typename T>() { compute<T>(stream, input, other, out); },
"Operator<Add, Device::Type::kCpu>::operator()");
out_type_,
[&](auto tag) {
using T = typename decltype(tag)::type;
Compute<T>(stream, input, other, out);
},
"`Operator<Add, Device::Type::kCpu>::operator()`");
}

private:
template <typename T>
void compute(void* stream, const Tensor input, const Tensor other,
void Compute(void* stream, const Tensor input, const Tensor other,
Tensor out) const {
const auto* input_ptr = static_cast<const T*>(input.data());
const auto* other_ptr = static_cast<const T*>(other.data());
Expand Down
3 changes: 2 additions & 1 deletion src/cuda/add/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ class CudaAdd : public Add {
Tensor out) const override {
DispatchFunc<AllTypes>(
out_type_,
[&]<typename T>() {
[&](auto tag) {
using T = typename decltype(tag)::type;
// TODO(lzm): currently hard-code block_size to be 256.
dim3 blockDims(
std::min(static_cast<Tensor::Size>(256), output_size_));
Expand Down
3 changes: 2 additions & 1 deletion src/cuda/rms_norm/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class CudaRmsNorm : public RmsNorm {

DispatchFunc<DataType::kFloat32, DataType::kFloat16, DataType::kBFloat16>(
out.dtype(),
[&]<typename T>() {
[&](auto tag) {
using T = typename decltype(tag)::type;
RmsNormKernel<kBlockSize, float, T, T>
<<<num_blocks, kBlockSize, 0, cuda_stream>>>(
reinterpret_cast<T*>(out.data()), stride_out_batch,
Expand Down
Loading