Skip to content

Adds tests for segment-specific-k-values#9311

Merged
elstehle merged 7 commits into
NVIDIA:mainfrom
elstehle:enh/topk-per-segment-k-tests
Jun 11, 2026
Merged

Adds tests for segment-specific-k-values#9311
elstehle merged 7 commits into
NVIDIA:mainfrom
elstehle:enh/topk-per-segment-k-tests

Conversation

@elstehle

@elstehle elstehle commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Description

Closes #7914

Currently tests are limited to a k that is uniform across all segments. We now extend tests to support segment-specific k values, i.e., k values that can vary across segments.

@elstehle elstehle requested a review from a team as a code owner June 9, 2026 06:27
@elstehle elstehle requested a review from gevtushenko June 9, 2026 06:27
@github-project-automation github-project-automation Bot moved this to Todo in CCCL Jun 9, 2026
@cccl-authenticator-app cccl-authenticator-app Bot moved this from Todo to In Review in CCCL Jun 9, 2026
@coderabbitai

coderabbitai Bot commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Review Change Stack

Note: CodeRabbit is enabled on this repository as a convenience for maintainers
and contributors. Use your best judgment when considering its review comments and
suggestions — a suggested change may be inadequate, unnecessary, or safe to ignore.
Contributors are not expected to address every comment. Human reviews are what
ultimately matter for merging.

Overview

This PR extends test coverage for DeviceBatchedTopK to support per-segment-specific k values (k that can vary per segment) and fixes a one-past-the-end/read-bounds issue encountered when computing compacted output sizes for exclusive_scan. Tests previously assumed a uniform k for all segments; this PR adds coverage and helpers for supplying k as a per-segment sequence (immediate or deferred) and validates correctness for both keys-only and key/value (pairs) variants across fixed-size and variable-size segments.

Changes

Test Cases

  • cub/test/catch2_test_device_segmented_topk_keys.cu

    • Adjusted dispatch call to pass namespace-qualified direction parameter.
    • Reworked key type selection to use conditional branches.
    • Updated existing "small variable-size segments" test to pass num_segments to the output-size op and switched some segment_size/k argument forms from immediate to deferred sequences where appropriate.
    • Added two device-side Catch2 tests validating per-segment k for DeviceBatchedTopK keys:
      • Fixed-size segments + per-segment k (immediate sequence): computes per-segment compacted offsets using per-segment k, runs batched_topk_keys, and verifies expected top-k per segment.
      • Variable-size segments + per-segment k (immediate sequence): builds variable offsets, computes guarded compacted offsets, runs batched_topk_keys with permutation iterators, and verifies per-segment results (sorting compacted outputs before comparison).
  • cub/test/catch2_test_device_segmented_topk_pairs.cu

    • Updated existing small-variable-segment test to pass num_segments to get_output_size_op and changed segment_sizes usage to deferred_sequence.
    • Added two device-side Catch2 tests validating per-segment k for DeviceBatchedTopK pairs (keys+values):
      • Fixed-size segments + per-segment k: uses per-segment k device vector (deferred sequence), computes compacted offsets, runs batched_topk_pairs with strided iterators, and verifies key/value association, uniqueness, and expected top-k keys.
      • Variable-size segments + per-segment k: uses random segment offsets and per-segment k (deferred sequence), computes guarded compacted offsets, runs batched_topk_pairs with permutation iterators, and verifies key/value association and expected top-k per segment (sorting compacted outputs before equality comparison).
    • Small namespace/qualification formatting fixes in call sites.

Test Utilities / Helpers

  • cub/test/catch2_test_device_topk_common.cuh
    • Added overload of compact_to_topk_batched that accepts a per-segment k iterator (non-integral KSizesItT via SFINAE), computes per-segment copy sizes using a get_output_size_op that now accepts num_segments and returns 0 for the exclusive-scan’s extra invocation to avoid out-of-bounds reads, performs exclusive scan to build destination offsets, and runs DeviceCopy::Batched to compact per-segment top-k.
    • Converted existing uniform-k overload into a delegating wrapper that forwards to the iterator-based overload using cuda::constant_iterator.
    • Added inclusion of <cuda/std/type_traits> to support SFINAE.

Benchmarks

  • cub/benchmarks/bench/segmented_topk/variable/indexed.cu
  • cub/benchmarks/bench/segmented_topk/variable/keys.cu
    • Updated construction of kernel launch argument objects to consistently use cuda::args::immediate / cuda::args::deferred_sequence (dropped leading :: qualification) and switched segment_sizes parameter to deferred_sequence with bounds wrapper.

Internal params change

  • cub/cub/detail/segmented_params.cuh
    • Modified availability of get_param overloads for cuda::args wrappers: moved/provided the get_param overload for ::cuda::args::immediate at this location and removed the overloads handling ::cuda::args::__constant_sequence and ::cuda::args::__immediate_sequence. (This affects template overload resolution for segmented parameter access.)

Fixes / Notes

  • Fixes a one-past-the-end/read-bounds issue in get_output_size_op used to compute compacted output offsets for exclusive_scan by adding a num_segments guard and returning 0 for the scan’s extra element.
  • Adds comprehensive tests for per-segment k handling in both keys-only and pairs variants across fixed-size and variable-size segments, exercising immediate and deferred sequence argument forms.
  • No changes to exported/public API declarations.

Scope / Impact

  • Files touched include tests (keys + pairs + common helpers), benchmarks, and an internal segmented_params header. Changes are primarily test additions and test-helper logic, with small benchmark and internal template overload adjustments that may affect template resolution for cuda::args wrappers.
  • Estimated review effort: High (new tests are straightforward but the header/template changes and the segmented_params overload adjustments require careful review).

suggestion:

Walkthrough

Adds per-segment-k compaction support and device tests for BatchedTopK (keys and pairs), updates get_output_size_op to accept num_segments, normalizes cuda::args usage (immediate/deferred/constant) across tests and benchmarks, and adjusts segmented parameter overloads.

Changes

Per-segment k support in DeviceBatchedTopK tests

Layer / File(s) Summary
Per-segment k iterator support in compact_to_topk_batched
cub/test/catch2_test_device_topk_common.cuh
New overload of compact_to_topk_batched accepts per-segment k via iterator and uses get_output_size_op(..., num_segments) to compute per-segment copy sizes; uniform-k overload delegates via cuda::constant_iterator.
Batched top-k keys tests with per-segment k
cub/test/catch2_test_device_segmented_topk_keys.cu
Adds fixed-size and variable-size segmented tests passing per-segment segment_k to batched_topk_keys, computes compacted offsets via get_output_size_op + exclusive scan, switches segment-size args to cuda::args::deferred_sequence where needed, and updates existing tests to pass num_segments to output-size computation.
Batched top-k pairs tests with per-segment k
cub/test/catch2_test_device_segmented_topk_pairs.cu
Adds fixed-size and variable-size segmented key/value tests mirroring keys tests: per-segment segment_k, compacted offsets, batched_topk_pairs calls with deferred/immediate sequences as appropriate, and verification of key/value association, uniqueness, and correctness.
cuda::args normalization and benchmarks
cub/benchmarks/bench/segmented_topk/variable/*, cub/cub/detail/segmented_params.cuh
Normalize cuda::args::* namespace usage at call sites, change benchmark segment-size args to cuda::args::deferred_sequence, and adjust segmented params overload set (remove old sequence-wrapper overloads, provide immediate overload).

Assessment against linked issues

Objective Addressed Explanation
Add per-segment k test support for fixed-size segments [#7914]
Add per-segment k test support for variable-size segments [#7914]
Extend tests to both keys and pairs variants [#7914]

Possibly related PRs

  • NVIDIA/cccl#9251: related refactor for making cuda::args namespace and wrappers public and adjusting call-site argument construction.

Suggested reviewers

  • gevtushenko
  • davebayer
  • fbusato

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
cub/test/catch2_test_device_topk_common.cuh (2)

66-70: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

critical: Out-of-bounds access when segment_id equals num_segments.

When this functor is invoked via exclusive_scan with num_segments + 1 elements (line 310), the last invocation passes segment_id = num_segments. Line 68 then accesses offset_it[segment_id + 1] = offset_it[num_segments + 1], but offset_it is constructed from d_offsets.cbegin() where d_offsets has size num_segments + 1 (valid indices 0..num_segments). Accessing index num_segments + 1 is undefined behavior.

Fix by either:

  1. Making the functor return 0 when segment_id >= num_segments, or
  2. Scanning only num_segments elements and computing the final offset separately, or
  3. Appending a sentinel element to ensure offsets has num_segments + 2 elements.

305-310: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

critical: Scanning num_segments+1 elements triggers out-of-bounds access in get_output_size_op.

The scan reads num_segments + 1 elements from copy_sizes_it, which calls get_output_size_op::operator() with segment_id values 0..num_segments. For segment_id = num_segments, the functor accesses offset_it[num_segments + 1], but d_offsets (line 298) has size num_segments + 1 with valid indices 0..num_segments. This is undefined behavior.

Recommended fix: scan only num_segments elements to compute the first num_segments output sizes, then set d_output_offsets[num_segments] separately via partial sum, or modify get_output_size_op to guard against segment_id == num_segments.

cub/test/catch2_test_device_segmented_topk_keys.cu (1)

230-235: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

critical: Existing test also has the out-of-bounds bug.

This existing variable-size test scans num_segments + 1 elements, which calls get_output_size_op with segment_id = num_segments, triggering out-of-bounds access at segment_offsets[num_segments + 1]. The bug exists in the helper function (catch2_test_device_topk_common.cuh:66-70) and affects both the existing and new tests.

cub/test/catch2_test_device_segmented_topk_pairs.cu (1)

309-316: ⚠️ Potential issue | 🔴 Critical | ⚡ Quick win

critical: Existing test also has the out-of-bounds bug.

This existing variable-size pairs test scans num_segments + 1 elements, which calls get_output_size_op with segment_id = num_segments, triggering out-of-bounds access at segment_offsets[num_segments + 1]. The bug exists in the helper function (catch2_test_device_topk_common.cuh:66-70) and affects both the existing and new tests.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 90fb2be4-ca19-4689-b154-741a3a53460c

📥 Commits

Reviewing files that changed from the base of the PR and between 3d2c45a and c6d102b.

📒 Files selected for processing (3)
  • cub/test/catch2_test_device_segmented_topk_keys.cu
  • cub/test/catch2_test_device_segmented_topk_pairs.cu
  • cub/test/catch2_test_device_topk_common.cuh

Comment thread cub/test/catch2_test_device_segmented_topk_keys.cu
Comment thread cub/test/catch2_test_device_segmented_topk_keys.cu
Comment thread cub/test/catch2_test_device_segmented_topk_pairs.cu
Comment thread cub/test/catch2_test_device_segmented_topk_pairs.cu
@github-actions

This comment has been minimized.

@NaderAlAwar NaderAlAwar left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Important: the tests take a very long time to build. The topk_keys file took around 430 seconds. I wonder if we can improve that so that not every combination of dtype/segment_size/k/direction is tested

Comment thread cub/test/catch2_test_device_segmented_topk_keys.cu Outdated
@elstehle elstehle force-pushed the enh/topk-per-segment-k-tests branch from 79fa4d2 to cbe943b Compare June 11, 2026 12:17
@elstehle elstehle requested a review from a team as a code owner June 11, 2026 12:17
@elstehle

Copy link
Copy Markdown
Contributor Author

Important: the tests take a very long time to build. The topk_keys file took around 430 seconds. I wonder if we can improve that so that not every combination of dtype/segment_size/k/direction is tested

Thanks for taking this into account! This is definitely beyond what we want to have. The culprit is the DeviceSegmentedSort we use for verification. I already took measures to bring its template instantiations down. Replacing it would half the compilation time but runtimes would get into the 1 min ballpark.

I have now split the TUs by key type, which helps bring down the time on the critical path. Also, replacing DeviceSegmentedSort with DeviceSegmentedRadixSort is another promising path that I will follow-up on in a subsequent PR.

@elstehle elstehle enabled auto-merge (squash) June 11, 2026 12:20

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 20992622-6243-4b29-8249-9a7bd64f1487

📥 Commits

Reviewing files that changed from the base of the PR and between 79fa4d2 and cbe943b.

📒 Files selected for processing (6)
  • cub/benchmarks/bench/segmented_topk/variable/indexed.cu
  • cub/benchmarks/bench/segmented_topk/variable/keys.cu
  • cub/cub/detail/segmented_params.cuh
  • cub/test/catch2_test_device_segmented_topk_keys.cu
  • cub/test/catch2_test_device_segmented_topk_pairs.cu
  • cub/test/catch2_test_device_topk_common.cuh
🚧 Files skipped from review as they are similar to previous changes (2)
  • cub/test/catch2_test_device_topk_common.cuh
  • cub/test/catch2_test_device_segmented_topk_keys.cu

Comment thread cub/test/catch2_test_device_segmented_topk_pairs.cu
Comment thread cub/test/catch2_test_device_segmented_topk_pairs.cu
@github-actions

Copy link
Copy Markdown
Contributor

🥳 CI Workflow Results

🟩 Finished in 1h 25m: Pass: 100%/287 | Total: 2d 21h | Max: 56m 37s | Hits: 85%/228240

See results here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

Archived in project

Development

Successfully merging this pull request may close these issues.

Add support for per-segment-specific k values to DeviceBatchedTopK

2 participants