diff --git a/cub/benchmarks/bench/segmented_topk/fixed/keys.cu b/cub/benchmarks/bench/segmented_topk/fixed/keys.cu index 41edc8dcc35..b8f13469dce 100644 --- a/cub/benchmarks/bench/segmented_topk/fixed/keys.cu +++ b/cub/benchmarks/bench/segmented_topk/fixed/keys.cu @@ -4,6 +4,7 @@ #include #include +#include #include #include @@ -44,33 +45,13 @@ void fixed_seg_size_topk_keys( nvbench::state& state, nvbench::type_list, nvbench::enum_type>) { - // Range of guaranteed total number of items - constexpr auto min_num_total_items = 1; - constexpr auto max_num_total_items = ::cuda::std::numeric_limits<::cuda::std::int32_t>::max(); - - // Static segment size - using seg_size_t = cub::detail::batched_topk::segment_size_static; - - // Static k (number of selected output elements per segment) - using k_value_t = cub::detail::batched_topk::k_static; - - // Static selection direction (max) - using select_direction_value_t = cub::detail::batched_topk::select_direction_static; - - // Number of segments is a host-accessible value - using num_segments_uniform_t = cub::detail::batched_topk::num_segments_uniform<>; - - // Total number of items guarantee type - using total_num_items_guarantee_t = - cub::detail::batched_topk::total_num_items_guarantee; - // Retrieve axis parameters const auto max_elements = static_cast(state.get_int64("Elements{io}")); const auto segment_size = static_cast<::cuda::std::ptrdiff_t>(MaxSegmentSize); const auto selected_elements = static_cast<::cuda::std::ptrdiff_t>(MaxNumSelected); const auto num_segments = ::cuda::std::max(1, (max_elements / segment_size)); const auto elements = num_segments * segment_size; - const auto total_num_items = total_num_items_guarantee_t{static_cast<::cuda::std::int64_t>(elements)}; + const auto total_num_items = ::cuda::__argument::__immediate{static_cast<::cuda::std::int64_t>(elements)}; const bit_entropy entropy = str_to_entropy(state.get_string("Entropy")); // Skip workloads where k exceeds the segment size @@ -87,9 +68,9 @@ void fixed_seg_size_topk_keys( auto d_keys_in = cuda::make_strided_iterator(cuda::make_counting_iterator(d_keys_in_ptr), segment_size); auto d_keys_out = cuda::make_strided_iterator(cuda::make_counting_iterator(d_keys_out_ptr), selected_elements); - auto segment_sizes = seg_size_t{}; - auto k = k_value_t{}; - auto select_directions = select_direction_value_t{}; + auto segment_sizes = ::cuda::__argument::__constant{}; + auto k = ::cuda::__argument::__constant{}; + auto select_direction = ::cuda::__argument::__constant{}; state.add_element_count(elements, "NumElements"); state.add_element_count(segment_size, "SegmentSize"); @@ -117,8 +98,8 @@ void fixed_seg_size_topk_keys( static_cast(nullptr), segment_sizes, k, - select_directions, - num_segments_uniform_t{static_cast<::cuda::std::int64_t>(num_segments)}, + select_direction, + ::cuda::__argument::__immediate{static_cast<::cuda::std::int64_t>(num_segments)}, total_num_items, env); }); diff --git a/cub/benchmarks/bench/segmented_topk/variable/keys.cu b/cub/benchmarks/bench/segmented_topk/variable/keys.cu index c4a41e66a33..3db3da44976 100644 --- a/cub/benchmarks/bench/segmented_topk/variable/keys.cu +++ b/cub/benchmarks/bench/segmented_topk/variable/keys.cu @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -171,20 +172,17 @@ void variable_seg_size_topk_keys(nvbench::state& state, static_cast(MaxSegmentSize)); const auto input_elements = thrust::reduce(d_segment_sizes.begin(), d_segment_sizes.end()); const auto output_elements = static_cast(num_segments) * K; - const auto total_num_items = - cub::detail::batched_topk::total_num_items_guarantee<1, cuda::std::numeric_limits::max()>{ - static_cast(input_elements)}; + const auto total_num_items = ::cuda::__argument::__immediate{static_cast(input_elements)}; auto in_keys_buffer = gen_data( num_segments, string_to_pattern(state.get_string("Pattern")), thrust::raw_pointer_cast(d_segment_sizes.data())); auto out_keys_buffer = thrust::device_vector(output_elements, thrust::no_init); - cub::detail::batched_topk::segment_size_per_segment segment_sizes_param{ - thrust::raw_pointer_cast(d_segment_sizes.data())}; - cub::detail::batched_topk::k_static k_param{}; - cub::detail::batched_topk::select_direction_static select_directions{}; - cub::detail::batched_topk::num_segments_uniform<> num_segments_uniform_param{ - static_cast(num_segments)}; + auto segment_sizes_param = ::cuda::__argument::__immediate_sequence{ + thrust::raw_pointer_cast(d_segment_sizes.data()), ::cuda::__argument::__bounds<1, MaxSegmentSize>()}; + auto k_param = ::cuda::__argument::__constant{}; + auto select_direction = ::cuda::__argument::__constant{}; + auto num_segments_param = ::cuda::__argument::__immediate{static_cast(num_segments)}; auto d_keys_in = cuda::make_strided_iterator( cuda::make_counting_iterator(thrust::raw_pointer_cast(in_keys_buffer.data())), @@ -210,8 +208,8 @@ void variable_seg_size_topk_keys(nvbench::state& state, static_cast(nullptr), segment_sizes_param, k_param, - select_directions, - num_segments_uniform_param, + select_direction, + num_segments_param, total_num_items, env); }); diff --git a/cub/cub/agent/agent_batched_topk.cuh b/cub/cub/agent/agent_batched_topk.cuh index 8540376feba..0bdaa8f38ef 100644 --- a/cub/cub/agent/agent_batched_topk.cuh +++ b/cub/cub/agent/agent_batched_topk.cuh @@ -23,6 +23,7 @@ #include #include +#include #include CUB_NAMESPACE_BEGIN @@ -72,8 +73,8 @@ struct agent_batched_topk_worker_per_segment using key_t = it_value_t; using value_t = it_value_t; - using segment_size_val_t = typename SegmentSizeParameterT::value_type; - using num_segments_val_t = typename NumSegmentsParameterT::value_type; + using segment_size_val_t = typename ::cuda::__argument::__traits::element_type; + using num_segments_val_t = typename ::cuda::__argument::__traits::element_type; using counters_t = batched_topk_counters; static constexpr auto policy = PolicyGetter{}(); @@ -94,7 +95,7 @@ struct agent_batched_topk_worker_per_segment multi_worker_per_segment_policy.threads_per_block * multi_worker_per_segment_policy.items_per_thread; // Check if there could be large segments present - static constexpr bool only_small_segments = params::static_max_value_v <= tile_size; + static constexpr bool only_small_segments = ::cuda::__argument::__traits::max <= tile_size; // Check if we are dealing with keys-only or key-value pairs static constexpr bool is_keys_only = ::cuda::std::is_same_v; @@ -190,16 +191,16 @@ struct agent_batched_topk_worker_per_segment // Boundary check // TODO (elstehle): consider skipping boundary check if we can safely assume the right grid dimensions - if (segment_id >= num_segments.get_param(0)) + if (segment_id >= params::get_param(num_segments, 0)) { return; } - constexpr bool is_full_tile = params::has_single_static_value_v - && params::static_min_value_v == tile_size; + constexpr bool is_full_tile = ::cuda::__argument::__traits::is_constant + && ::cuda::__argument::__traits::lowest == tile_size; // Resolve Segment Parameters - const auto segment_size = segment_sizes.get_param(segment_id); + const auto segment_size = params::get_param(segment_sizes, segment_id); if (!only_small_segments && segment_size > tile_size) { // Enqueue large segment @@ -215,8 +216,8 @@ struct agent_batched_topk_worker_per_segment else { // Process small segment - const auto k = (::cuda::std::min) (k_param.get_param(segment_id), - static_cast(segment_size)); + const auto k = (::cuda::std::min) (params::get_param(k_param, segment_id), + static_cast(segment_size)); const auto direction = select_directions.get_param(segment_id); // Determine padding key based on direction diff --git a/cub/cub/detail/segmented_params.cuh b/cub/cub/detail/segmented_params.cuh index 696f61c1797..fe5cc5c9162 100644 --- a/cub/cub/detail/segmented_params.cuh +++ b/cub/cub/detail/segmented_params.cuh @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #pragma once @@ -13,127 +13,93 @@ # pragma system_header #endif // no system header -#include +#include #include +#include #include -#include -#include +#include CUB_NAMESPACE_BEGIN namespace detail::params { -// ----------------------------------------------------------------------------- -// Parameter Mixins and Helpers -// ----------------------------------------------------------------------------- - -// Allows providing constrains on parameter values at compile time -template ::lowest(), T Max = ::cuda::std::numeric_limits::max()> -struct static_bounds_mixin +// ===================================================================== +// get_param — unified segment parameter access +// ===================================================================== + +//! @brief Returns the value of an argument for a given segment index. +//! +//! @param[in] __arg Argument or argument wrapper to read. +//! @param[in] __index Segment index to read for sequence arguments. +//! @return The single argument value, or the sequence element at the given index. +_CCCL_TEMPLATE(class _Tp, class _SegmentIndexT) +_CCCL_REQUIRES((!::cuda::__argument::__is_wrapper_v<::cuda::std::remove_cvref_t<_Tp>>) ) +[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto get_param(_Tp&& __arg, [[maybe_unused]] _SegmentIndexT __index) noexcept { - static_assert(Min <= Max, "Min must be <= Max"); - - // Compile-time bounds - static constexpr T static_min_value = Min; - static constexpr T static_max_value = Max; - - // Indicates that there's only one possible value - static constexpr bool is_exact = (Min == Max); -}; + if constexpr (::cuda::__argument::__traits<::cuda::std::remove_cvref_t<_Tp>>::is_single_value) + { + return __arg; + } + else + { + return __arg[__index]; + } +} -// Allows specifying a list of supported options for a parameter. E.g., the orders (ascending, descending) that are -// supported by a sorting algorithm. -template -struct supported_options +template +[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto +get_param(const ::cuda::__argument::__constant<_Value>& __arg, [[maybe_unused]] _SegmentIndexT __index) noexcept { - static constexpr size_t count = sizeof...(Options); -}; - -// ----------------------------------------------------------------------------- -// Fundamental Parameter Types -// ----------------------------------------------------------------------------- + return ::cuda::__argument::__unwrap(__arg); +} -// A compile-time constant -template -struct static_constant_param : public static_bounds_mixin +template +[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto +get_param(const ::cuda::__argument::__constant_sequence<_Value>& __arg, _SegmentIndexT __index) noexcept { - using value_type = T; + return ::cuda::__argument::__unwrap(__arg)[__index]; +} - template - _CCCL_HOST_DEVICE constexpr auto get_param([[maybe_unused]] SegmentIndexT segment_id) const - { - static_assert(static_bounds_mixin::is_exact, "Static parameter must have exact value"); - return static_bounds_mixin::static_min_value; - } -}; -// ----------------------------------------------------------------------------- -// 1. Uniform Param -// ----------------------------------------------------------------------------- -// Added default template args so CTAD can deduce T and default Min/Max -template ::lowest(), T Max = ::cuda::std::numeric_limits::max()> -struct uniform_param : public static_bounds_mixin +template +[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto get_param( + const ::cuda::__argument::__immediate<_Arg, _StaticBounds>& __arg, [[maybe_unused]] _SegmentIndexT __index) noexcept { - using value_type = T; - - T value; - - _CCCL_HOST_DEVICE constexpr uniform_param(T v) - : value(v) - {} - - uniform_param() = default; - - template - _CCCL_HOST_DEVICE constexpr auto get_param([[maybe_unused]] SegmentIndexT segment_id) const - { - return value; - } -}; + return ::cuda::__argument::__unwrap(__arg); +} -template -uniform_param(T) -> uniform_param; - -// ----------------------------------------------------------------------------- -// 2. Per-Segment Param -// ----------------------------------------------------------------------------- -// Added defaults for T, Min, and Max based on the Iterator's value_type -template ::value_type, - T Min = ::cuda::std::numeric_limits::lowest(), - T Max = ::cuda::std::numeric_limits::max()> -struct per_segment_param : public static_bounds_mixin +template +[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto +get_param(const ::cuda::__argument::__immediate_sequence<_Arg, _StaticBounds>& __arg, _SegmentIndexT __index) noexcept { - using iterator_type = IteratorT; - using value_type = T; + return ::cuda::__argument::__unwrap(__arg)[__index]; +} - IteratorT iterator; - T min_value = Min; - T max_value = Max; +template +[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto get_param( + const ::cuda::__argument::__deferred<_Arg, _StaticBounds>& __arg, [[maybe_unused]] _SegmentIndexT __index) noexcept +{ + return ::cuda::__argument::__unwrap(__arg); +} - _CCCL_HOST_DEVICE constexpr per_segment_param(IteratorT iter, T min_v = Min, T max_v = Max) - : iterator(iter) - , min_value(min_v) - , max_value(max_v) - {} +template +[[nodiscard]] _CCCL_HOST_DEVICE constexpr auto +get_param(const ::cuda::__argument::__deferred_sequence<_Arg, _StaticBounds>& __arg, _SegmentIndexT __index) noexcept +{ + return ::cuda::__argument::__unwrap(__arg)[__index]; +} - per_segment_param() = default; +// ===================================================================== +// Discrete parameter support +// ===================================================================== - template - _CCCL_HOST_DEVICE constexpr auto get_param(SegmentIndexT segment_id) const - { - return iterator[segment_id]; - } +//! @brief Specifies a list of supported options for a parameter. +template +struct supported_options +{ + static constexpr ::cuda::std::size_t count = sizeof...(Options); }; -// Deduction Guide: -// Allows: per_segment_param{iter} -> per_segment_param -template -per_segment_param(IteratorT) -> per_segment_param; - -// ----------------------------------------------------------------------------- -// 3. Uniform Discrete Param -// ----------------------------------------------------------------------------- +//! @brief Uniform discrete parameter — a single runtime value with a known set of supported options. template struct uniform_discrete_param { @@ -155,9 +121,7 @@ struct uniform_discrete_param } }; -// ----------------------------------------------------------------------------- -// 4. Per-Segment Discrete Param -// ----------------------------------------------------------------------------- +//! @brief Per-segment discrete parameter — per-segment values with a known set of supported options. template struct per_segment_discrete_param { @@ -180,71 +144,40 @@ struct per_segment_discrete_param } }; -// ----------------------------------------------------------------------------- -// Parameter Type Helpers -// ----------------------------------------------------------------------------- -template -inline constexpr bool is_static_param_v = false; - -template -inline constexpr bool is_static_param_v> = true; - -template -inline constexpr bool is_uniform_param_v = false; - -template -inline constexpr bool is_uniform_param_v> = true; - -template -inline constexpr bool is_uniform_param_v> = true; - -template -inline constexpr bool is_per_segment_param_v = false; - -template -inline constexpr bool is_per_segment_param_v> = true; - -template -inline constexpr bool is_per_segment_param_v> = true; - -// Get max value (works for all types inheriting bounds_mixin) -template -inline constexpr auto static_max_value_v = T::static_max_value; - -// Get min value (works for all types inheriting bounds_mixin) -template -inline constexpr auto static_min_value_v = T::static_min_value; - -// Whether a given parameter allows only for a single static value -template -inline constexpr bool has_single_static_value_v = (static_max_value_v == static_min_value_v); - -// Helper that translates a runtime parameter value into a compile-time constant by matching against a list of supported -// options. +// ===================================================================== +// Discrete dispatch +// ===================================================================== + +//! @brief Translates a runtime parameter value into a compile-time constant by matching +//! against a list of supported options. +//! +//! @param[in] val Runtime value to match. +//! @param[in] __supported_options Supported values for the parameter. +//! @param[in] f Functor invoked with the matched compile-time constant. +//! @return `true` if the value matches one of the supported options. template -_CCCL_HOST_DEVICE bool dispatch_impl(T val, supported_options, Functor&& f) +[[nodiscard]] _CCCL_HOST_DEVICE bool +dispatch_impl(T val, [[maybe_unused]] supported_options __supported_options, Functor&& f) { - // Fold expression over the supported options. - // This generates code equivalent to: - // if (val == Opt1) f(integral_constant); - // else if (val == Opt2) f(integral_constant); - // ... const bool match_found = ((val == Opts ? (f(::cuda::std::integral_constant{}), true) : false) || ...); - - // Optional: Handling cases where the runtime value was not in the supported - // list. In a release build, we assume the user respected the contract. _CCCL_ASSERT(match_found, "The given runtime parameter value is not in the supported list"); return match_found; } -// Dispatcher that matches a runtime parameter value against a list of supported options and invokes a functor with the -// matched option as a compile-time constant. +//! @brief Dispatcher that resolves a per-segment discrete parameter to a compile-time constant +//! and invokes a functor with the matched option. +//! +//! @param[in] param Discrete parameter to resolve. +//! @param[in] segment_id Segment index to read from `param`. +//! @param[in] f Functor invoked with the matched compile-time constant. +//! @return `true` if the parameter value matches one of its supported options. template -_CCCL_HOST_DEVICE bool dispatch_discrete(ParamT param, SegmentIndexT segment_id, Functor&& f) +[[nodiscard]] _CCCL_HOST_DEVICE bool dispatch_discrete(ParamT param, SegmentIndexT segment_id, Functor&& f) { using supported_list = typename ParamT::supported_options_t; auto param_value = param.get_param(segment_id); - return dispatch_impl(param_value, supported_list{}, ::cuda::std::forward(f)); + return CUB_NS_QUALIFIER::detail::params::dispatch_impl( + param_value, supported_list{}, ::cuda::std::forward(f)); } } // namespace detail::params diff --git a/cub/cub/device/dispatch/dispatch_batched_topk.cuh b/cub/cub/device/dispatch/dispatch_batched_topk.cuh index 4e9714b3ceb..3d8293742eb 100644 --- a/cub/cub/device/dispatch/dispatch_batched_topk.cuh +++ b/cub/cub/device/dispatch/dispatch_batched_topk.cuh @@ -31,11 +31,13 @@ #include +#include #include #include #include #include #include +#include #include #include @@ -44,103 +46,32 @@ CUB_NAMESPACE_BEGIN namespace detail::batched_topk { // ----------------------------------------------------------------------------- -// Segmented Top-K-Specific Parameter Types +// Internal: wrap user-facing select direction into discrete param for dispatch // ----------------------------------------------------------------------------- -// ------------ SELECTION DIRECTION PARAMETER TYPES ------------ - -// Selection direction known at compile time, same value applies to all segments -template -using select_direction_static = params::uniform_discrete_param; - -// Selection direction is a runtime value, same value applies to all segments -using select_direction_uniform = - params::uniform_discrete_param; - -// Per-segment selection direction via iterator -template -using select_direction_per_segment = - params::per_segment_discrete_param; - -// ------------ SEGMENT SIZE PARAMETER TYPES ------------ - -// Segment size known at compile time, same value applies to all segments -template <::cuda::std::int64_t SegmentSize> -using segment_size_static = params::static_constant_param<::cuda::std::int64_t, SegmentSize>; - -// Segment size is a runtime value, same value applies to all segments -template <::cuda::std::int64_t MinSegmentSize = 0, - ::cuda::std::int64_t MaxSegmentSize = ::cuda::std::numeric_limits<::cuda::std::int64_t>::max()> -using segment_size_uniform = params::uniform_param<::cuda::std::int64_t, MinSegmentSize, MaxSegmentSize>; - -// Segment size via iterator -template ::max()> -using segment_size_per_segment = - params::per_segment_param; - -// ------------ K PARAMETER TYPES ------------ - -// K known at compile time, same value applies to all segments -template <::cuda::std::int64_t K> -using k_static = params::static_constant_param<::cuda::std::int64_t, K>; - -// K is a runtime value, same value applies to all segments -template <::cuda::std::int64_t MinK = 1, - ::cuda::std::int64_t MaxK = ::cuda::std::numeric_limits<::cuda::std::int64_t>::max()> -using k_uniform = params::uniform_param<::cuda::std::int64_t, MinK, MaxK>; - -// K via iterator -template ::max()> -using k_per_segment = params::per_segment_param; - -// ------------ TOTAL NUMBER OF SEGMENTS ------------ -// Number of segments known at compile time -template <::cuda::std::int64_t StaticNumSegments> -using num_segments_static = params::static_constant_param<::cuda::std::int64_t, StaticNumSegments>; - -// Number of segments is a runtime value -template <::cuda::std::int64_t MinNumSegments = 1, - ::cuda::std::int64_t MaxNumSegments = ::cuda::std::numeric_limits<::cuda::std::int64_t>::max()> -using num_segments_uniform = params::uniform_param<::cuda::std::int64_t, MinNumSegments, MaxNumSegments>; - -// Number of segments via iterator -template ::max()> -using num_segments_indirect = - params::per_segment_param; - -// ------------ TOTAL NUMBER OF ITEMS PARAMETER TYPES ------------ - -// Number of items guarantee -template <::cuda::std::int64_t MinNumItems = 1, - ::cuda::std::int64_t MaxNumItems = ::cuda::std::numeric_limits<::cuda::std::int64_t>::max()> -struct total_num_items_guarantee +// Uniform (compile-time): __constant -> single-option uniform_discrete_param. +template +[[nodiscard]] _CCCL_HOST_DEVICE auto wrap_select_direction(::cuda::__argument::__constant) { - using value_type = ::cuda::std::int64_t; - static constexpr value_type static_min_num_items = MinNumItems; - static constexpr value_type static_max_num_items = MaxNumItems; - - value_type min_num_items = MinNumItems; - value_type max_num_items = MaxNumItems; - - // Create default ctor, 1 param ctor taking min, 2 param ctor taking min/max - total_num_items_guarantee() = default; + return params::uniform_discrete_param{Dir}; +} - _CCCL_HOST_DEVICE total_num_items_guarantee(value_type num_items) - : min_num_items(num_items) - , max_num_items(num_items) - {} +// Uniform: single enum value → uniform_discrete_param +[[nodiscard]] _CCCL_HOST_DEVICE inline auto wrap_select_direction(detail::topk::select dir) +{ + return params::uniform_discrete_param{ + dir}; +} - _CCCL_HOST_DEVICE total_num_items_guarantee(value_type min_items, value_type max_items) - : min_num_items(min_items) - , max_num_items(max_items) - {} -}; +// Per-segment: iterator of enums → per_segment_discrete_param +_CCCL_TEMPLATE(typename IteratorT) +_CCCL_REQUIRES((!::cuda::std::is_same_v<::cuda::std::remove_cv_t, detail::topk::select>) ) +[[nodiscard]] _CCCL_HOST_DEVICE auto wrap_select_direction(IteratorT iter) +{ + return params:: + per_segment_discrete_param{ + iter}; +} // ----------------------------------------------------------------------------- // Helper: turn a segment ID into the number of large-segment-agent tiles needed @@ -158,7 +89,7 @@ struct segment_size_to_tile_count_op _CCCL_HOST_DEVICE _CCCL_FORCEINLINE constexpr TotalNumItemsValueType operator()(SegmentIndexT segment_id) const { return static_cast( - ::cuda::ceil_div(segment_sizes.get_param(segment_id), large_segment_agent_tile_size)); + ::cuda::ceil_div(params::get_param(segment_sizes, segment_id), large_segment_agent_tile_size)); } }; @@ -189,13 +120,13 @@ template >, it_value_t>, ::cuda::std::int64_t, - params::static_max_value_v>> + ::cuda::__argument::__traits::max>> #if _CCCL_HAS_CONCEPTS() requires batched_topk_policy_selector #endif // _CCCL_HAS_CONCEPTS() @@ -208,13 +139,18 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( ValueOutputItItT d_value_segments_out_it, SegmentSizeParameterT segment_sizes, KParameterT k, - SelectDirectionParameterT select_directions, + SelectDirectionT select_direction, NumSegmentsParameterT num_segments, [[maybe_unused]] TotalNumItemsGuaranteeT total_num_items_guarantee, cudaStream_t stream = nullptr, [[maybe_unused]] PolicySelector policy_selector = {}) { - using large_segment_tile_offset_t = typename TotalNumItemsGuaranteeT::value_type; + using large_segment_tile_offset_t = typename ::cuda::__argument::__traits::element_type; + + // Wrap the raw enum into the internal discrete param type + auto select_directions = wrap_select_direction(select_direction); + using SelectDirectionParameterT = decltype(select_directions); + // Helper that determines (a) whether there's any one-worker-per-segment policy supporting the range of segment // sizes and k, and (b) if so, which set of one-worker-per-segment policies to use constexpr auto policy = find_smallest_covering_policy< @@ -235,9 +171,9 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( static constexpr int worker_per_segment_tile_size = worker_per_segment_policy.threads_per_block * worker_per_segment_policy.items_per_thread; static constexpr bool any_small_segments = - params::static_min_value_v <= worker_per_segment_tile_size; + ::cuda::__argument::__traits::lowest <= worker_per_segment_tile_size; static constexpr bool only_small_segments = - params::static_max_value_v <= worker_per_segment_tile_size; + ::cuda::__argument::__traits::max <= worker_per_segment_tile_size; // Allocation layout: // only_small_segments: [0] dummy. @@ -247,7 +183,7 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( static constexpr int allocations_array_size = only_small_segments ? 1 : (any_small_segments ? 3 : 2); size_t allocation_sizes[allocations_array_size] = {1}; - using num_segments_val_t = typename NumSegmentsParameterT::value_type; + using num_segments_val_t = typename ::cuda::__argument::__traits::element_type; using counters_t = batched_topk_counters; using segment_size_scan_offset_t = detail::choose_offset_t; using segment_size_scan_input_op_t = @@ -261,7 +197,7 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( if constexpr (!only_small_segments) { - const auto num_segments_val = num_segments.get_param(0); + const auto num_segments_val = params::get_param(num_segments, 0); // Scan output allocation_sizes[0] = num_segments_val * sizeof(large_segment_tile_offset_t); if constexpr (any_small_segments) @@ -303,7 +239,7 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( // TODO (elstehle): support number of segments provided by device-accessible iterator // Only uniform number of segments are supported (i.e., we need to resolve the number of segments on the host) - static_assert(!params::is_per_segment_param_v, + static_assert(::cuda::__argument::__traits::is_single_value, "Only uniform segment sizes are currently supported."); if constexpr (any_small_segments) @@ -317,7 +253,7 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( return error; } } - const int grid_dim = static_cast(num_segments.get_param(0)); + const int grid_dim = static_cast(params::get_param(num_segments, 0)); constexpr int block_dim = worker_per_segment_policy.threads_per_block; if (const auto error = CubDebug( THRUST_NS_QUALIFIER::cuda_cub::detail::triple_chevron(grid_dim, block_dim, 0, stream) @@ -361,7 +297,7 @@ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t dispatch( static_cast(allocations[0]), ::cuda::std::plus<>{}, detail::InputValue(large_segment_tile_offset_t{0}), - static_cast(num_segments.get_param(0)), + static_cast(params::get_param(num_segments, 0)), stream))) { return error; @@ -405,7 +341,7 @@ template >, it_value_t>, ::cuda::std::int64_t, - params::static_max_value_v>; + ::cuda::__argument::__traits::max>; return detail::dispatch_with_env_and_tuning( env, [&](auto policy_selector, void* d_temp_storage, size_t& temp_storage_bytes, cudaStream_t stream) { return dispatch( diff --git a/cub/cub/device/dispatch/kernels/kernel_batched_topk.cuh b/cub/cub/device/dispatch/kernels/kernel_batched_topk.cuh index 223b63b93c1..63ad24aaeac 100644 --- a/cub/cub/device/dispatch/kernels/kernel_batched_topk.cuh +++ b/cub/cub/device/dispatch/kernels/kernel_batched_topk.cuh @@ -17,10 +17,10 @@ #endif // no system header #include -#include #include #include +#include #include CUB_NAMESPACE_BEGIN @@ -39,7 +39,7 @@ private: worker_policy worker_per_segment_policy; multi_worker_policy multi_worker_per_segment_policy; }; - static constexpr ::cuda::std::int64_t max_segment_size = params::static_max_value_v; + static constexpr ::cuda::std::int64_t max_segment_size = ::cuda::__argument::__traits::max; static constexpr batched_topk_policy active_policy = current_policy(); template @@ -133,8 +133,8 @@ __launch_bounds__(int( KParameterT k, SelectDirectionParameterT select_directions, NumSegmentsParameterT num_segments, - batched_topk_counters* d_counters, - typename NumSegmentsParameterT::value_type* d_large_segments_ids, + batched_topk_counters::element_type>* d_counters, + typename ::cuda::__argument::__traits::element_type* d_large_segments_ids, LargeSegmentTileOffsetT* d_large_segments_tile_offsets) { using agent_t = typename find_smallest_covering_policy< @@ -151,7 +151,7 @@ __launch_bounds__(int( LargeSegmentTileOffsetT>::agent_t; // Static Assertions (Constraints) - static_assert(agent_t::tile_size >= params::static_max_value_v, + static_assert(agent_t::tile_size >= ::cuda::__argument::__traits::max, "Block size exceeds maximum segment size supported by SegmentSizeParameterT"); static_assert(sizeof(typename agent_t::TempStorage) <= max_smem_per_block, "Static shared memory per block must not exceed 48KB limit."); diff --git a/cub/test/catch2_test_device_segmented_topk_keys.cu b/cub/test/catch2_test_device_segmented_topk_keys.cu index 164ceb231e6..3d00c1119cc 100644 --- a/cub/test/catch2_test_device_segmented_topk_keys.cu +++ b/cub/test/catch2_test_device_segmented_topk_keys.cu @@ -31,7 +31,7 @@ template CUB_RUNTIME_FUNCTION static cudaError_t dispatch_batched_topk_keys( @@ -41,7 +41,7 @@ CUB_RUNTIME_FUNCTION static cudaError_t dispatch_batched_topk_keys( KeyOutputItItT d_key_segments_out_it, SegmentSizeParamT segment_sizes, KParamT k, - SelectDirectionParamT select_directions, + SelectDirectionT select_direction, NumSegmentsParameterT num_segments, TotalNumItemsGuaranteeT total_num_items_guarantee, cudaStream_t stream = nullptr) @@ -56,7 +56,7 @@ CUB_RUNTIME_FUNCTION static cudaError_t dispatch_batched_topk_keys( values_it, segment_sizes, k, - select_directions, + select_direction, num_segments, total_num_items_guarantee, stream); @@ -151,11 +151,11 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small fixed-size segments", batched_topk_keys( d_keys_in, d_keys_out, - cub::detail::batched_topk::segment_size_uniform<1, max_segment_size>{segment_size}, - cub::detail::batched_topk::k_uniform<1, static_max_k>{k}, - cub::detail::batched_topk::select_direction_uniform{direction}, - cub::detail::batched_topk::num_segments_uniform<>{num_segments}, - cub::detail::batched_topk::total_num_items_guarantee{num_segments * segment_size}); + ::cuda::__argument::__immediate{segment_size, ::cuda::__argument::__bounds()}, + ::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds()}, + direction, + ::cuda::__argument::__immediate{num_segments}, + ::cuda::__argument::__immediate{num_segments * segment_size}); // Prepare expected results fixed_size_segmented_sort_keys(expected_keys, num_segments, segment_size, direction); compact_sorted_keys_to_topk(expected_keys, segment_size, k); @@ -248,12 +248,12 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small variable-size segment batched_topk_keys( d_keys_in, d_keys_out, - cub::detail::batched_topk::segment_size_per_segment{ - segment_size_it}, - cub::detail::batched_topk::k_uniform<1, static_max_k>{k}, - cub::detail::batched_topk::select_direction_uniform{direction}, - cub::detail::batched_topk::num_segments_uniform<>{num_segments}, - cub::detail::batched_topk::total_num_items_guarantee{num_items}); + ::cuda::__argument::__immediate_sequence{ + segment_size_it, ::cuda::__argument::__bounds()}, + ::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds()}, + direction, + ::cuda::__argument::__immediate{num_segments}, + ::cuda::__argument::__immediate{num_items}); // Verify keys are returned correctly: sort each segment of the expected input, then compact the top-k segmented_sort_keys(expected_keys, num_segments, segment_offsets.cbegin(), segment_offsets.cbegin() + 1, direction); @@ -269,10 +269,10 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Keys work with small variable-size segment // Regression test: top-k must preserve -0.0f in the output (not normalize to +0.0f). C2H_TEST("DeviceBatchedTopK::MinKeys preserves -0.0f in output", "[keys][segmented][topk][device][float]") { - constexpr cuda::std::int64_t segment_size = 8; - constexpr cuda::std::int64_t k = 5; - constexpr cuda::std::int64_t num_segments = 1; - [[maybe_unused]] constexpr cuda::std::size_t max_segment_size = 64; // msvc warns, only used in nttp + constexpr cuda::std::int64_t segment_size = 8; + constexpr cuda::std::int64_t k = 5; + constexpr cuda::std::int64_t num_segments = 1; + [[maybe_unused]] constexpr cuda::std::int64_t max_segment_size = 64; // msvc warns, only used in nttp // Input: one segment containing -0.0f and +0.0f; top-5 min should include both zeros. c2h::device_vector d_keys_in{3.0f, -0.0f, 1.0f, 2.0f, 0.0f, -1.0f, 4.0f, 5.0f}; @@ -286,11 +286,12 @@ C2H_TEST("DeviceBatchedTopK::MinKeys preserves -0.0f in output", "[keys][segment batched_topk_keys( d_keys_in_it, d_keys_out_it, - cub::detail::batched_topk::segment_size_uniform<1, max_segment_size>{segment_size}, - cub::detail::batched_topk::k_uniform<1, static_cast(k)>{k}, - cub::detail::batched_topk::select_direction_uniform{cub::detail::topk::select::min}, - cub::detail::batched_topk::num_segments_uniform<>{num_segments}, - cub::detail::batched_topk::total_num_items_guarantee{num_segments * segment_size}); + ::cuda::__argument::__immediate{ + segment_size, ::cuda::__argument::__bounds()}, + ::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds()}, + cub::detail::topk::select::min, + ::cuda::__argument::__immediate{num_segments}, + ::cuda::__argument::__immediate{num_segments * segment_size}); const int num_minus_zero = static_cast(thrust::count_if(d_keys_out.begin(), d_keys_out.end(), is_minus_zero{})); REQUIRE(num_minus_zero >= 1); diff --git a/cub/test/catch2_test_device_segmented_topk_pairs.cu b/cub/test/catch2_test_device_segmented_topk_pairs.cu index a519cff9959..cc34ceba3c6 100644 --- a/cub/test/catch2_test_device_segmented_topk_pairs.cu +++ b/cub/test/catch2_test_device_segmented_topk_pairs.cu @@ -220,11 +220,11 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small fixed-size segments" d_keys_out, d_values_in, d_values_out, - cub::detail::batched_topk::segment_size_uniform<1, max_segment_size>{segment_size}, - cub::detail::batched_topk::k_uniform<1, static_max_k>{k}, - cub::detail::batched_topk::select_direction_uniform{direction}, - cub::detail::batched_topk::num_segments_uniform<>{num_segments}, - cub::detail::batched_topk::total_num_items_guarantee{num_segments * segment_size}); + ::cuda::__argument::__immediate{segment_size, ::cuda::__argument::__bounds()}, + ::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds()}, + direction, + ::cuda::__argument::__immediate{num_segments}, + ::cuda::__argument::__immediate{num_segments * segment_size}); // Verification: // - We verify correct top-k selection through the keys @@ -340,12 +340,12 @@ C2H_TEST("DeviceBatchedTopK::{Min,Max}Pairs work with small variable-size segmen d_keys_out, d_values_in, d_values_out, - cub::detail::batched_topk::segment_size_per_segment{ - segment_size_it}, - cub::detail::batched_topk::k_uniform<1, static_max_k>{k}, - cub::detail::batched_topk::select_direction_uniform{direction}, - cub::detail::batched_topk::num_segments_uniform<>{num_segments}, - cub::detail::batched_topk::total_num_items_guarantee{num_items}); + ::cuda::__argument::__immediate_sequence{ + segment_size_it, ::cuda::__argument::__bounds()}, + ::cuda::__argument::__immediate{k, ::cuda::__argument::__bounds()}, + direction, + ::cuda::__argument::__immediate{num_segments}, + ::cuda::__argument::__immediate{num_items}); // Verification: // - We verify correct top-k selection through the keys