Adds tests for segment-specific-k-values#9311
Conversation
OverviewThis PR extends test coverage for DeviceBatchedTopK to support per-segment-specific k values (k that can vary per segment) and fixes a one-past-the-end/read-bounds issue encountered when computing compacted output sizes for exclusive_scan. Tests previously assumed a uniform k for all segments; this PR adds coverage and helpers for supplying k as a per-segment sequence (immediate or deferred) and validates correctness for both keys-only and key/value (pairs) variants across fixed-size and variable-size segments. ChangesTest Cases
Test Utilities / Helpers
Benchmarks
Internal params change
Fixes / Notes
Scope / Impact
suggestion: WalkthroughAdds per-segment-k compaction support and device tests for BatchedTopK (keys and pairs), updates get_output_size_op to accept num_segments, normalizes cuda::args usage (immediate/deferred/constant) across tests and benchmarks, and adjusts segmented parameter overloads. ChangesPer-segment k support in DeviceBatchedTopK tests
Assessment against linked issues
Possibly related PRs
Suggested reviewers
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
cub/test/catch2_test_device_topk_common.cuh (2)
66-70:⚠️ Potential issue | 🔴 Critical | ⚡ Quick wincritical: Out-of-bounds access when segment_id equals num_segments.
When this functor is invoked via exclusive_scan with
num_segments + 1elements (line 310), the last invocation passessegment_id = num_segments. Line 68 then accessesoffset_it[segment_id + 1] = offset_it[num_segments + 1], butoffset_itis constructed fromd_offsets.cbegin()whered_offsetshas sizenum_segments + 1(valid indices 0..num_segments). Accessing indexnum_segments + 1is undefined behavior.Fix by either:
- Making the functor return 0 when
segment_id >= num_segments, or- Scanning only
num_segmentselements and computing the final offset separately, or- Appending a sentinel element to ensure offsets has
num_segments + 2elements.
305-310:⚠️ Potential issue | 🔴 Critical | ⚡ Quick wincritical: Scanning num_segments+1 elements triggers out-of-bounds access in get_output_size_op.
The scan reads
num_segments + 1elements fromcopy_sizes_it, which callsget_output_size_op::operator()withsegment_idvalues 0..num_segments. Forsegment_id = num_segments, the functor accessesoffset_it[num_segments + 1], butd_offsets(line 298) has sizenum_segments + 1with valid indices 0..num_segments. This is undefined behavior.Recommended fix: scan only
num_segmentselements to compute the firstnum_segmentsoutput sizes, then setd_output_offsets[num_segments]separately via partial sum, or modifyget_output_size_opto guard againstsegment_id == num_segments.cub/test/catch2_test_device_segmented_topk_keys.cu (1)
230-235:⚠️ Potential issue | 🔴 Critical | ⚡ Quick wincritical: Existing test also has the out-of-bounds bug.
This existing variable-size test scans
num_segments + 1elements, which callsget_output_size_opwithsegment_id = num_segments, triggering out-of-bounds access atsegment_offsets[num_segments + 1]. The bug exists in the helper function (catch2_test_device_topk_common.cuh:66-70) and affects both the existing and new tests.cub/test/catch2_test_device_segmented_topk_pairs.cu (1)
309-316:⚠️ Potential issue | 🔴 Critical | ⚡ Quick wincritical: Existing test also has the out-of-bounds bug.
This existing variable-size pairs test scans
num_segments + 1elements, which callsget_output_size_opwithsegment_id = num_segments, triggering out-of-bounds access atsegment_offsets[num_segments + 1]. The bug exists in the helper function (catch2_test_device_topk_common.cuh:66-70) and affects both the existing and new tests.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 90fb2be4-ca19-4689-b154-741a3a53460c
📒 Files selected for processing (3)
cub/test/catch2_test_device_segmented_topk_keys.cucub/test/catch2_test_device_segmented_topk_pairs.cucub/test/catch2_test_device_topk_common.cuh
This comment has been minimized.
This comment has been minimized.
NaderAlAwar
left a comment
There was a problem hiding this comment.
Important: the tests take a very long time to build. The topk_keys file took around 430 seconds. I wonder if we can improve that so that not every combination of dtype/segment_size/k/direction is tested
79fa4d2 to
cbe943b
Compare
Thanks for taking this into account! This is definitely beyond what we want to have. The culprit is the I have now split the TUs by key type, which helps bring down the time on the critical path. Also, replacing DeviceSegmentedSort with DeviceSegmentedRadixSort is another promising path that I will follow-up on in a subsequent PR. |
There was a problem hiding this comment.
Actionable comments posted: 2
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 20992622-6243-4b29-8249-9a7bd64f1487
📒 Files selected for processing (6)
cub/benchmarks/bench/segmented_topk/variable/indexed.cucub/benchmarks/bench/segmented_topk/variable/keys.cucub/cub/detail/segmented_params.cuhcub/test/catch2_test_device_segmented_topk_keys.cucub/test/catch2_test_device_segmented_topk_pairs.cucub/test/catch2_test_device_topk_common.cuh
🚧 Files skipped from review as they are similar to previous changes (2)
- cub/test/catch2_test_device_topk_common.cuh
- cub/test/catch2_test_device_segmented_topk_keys.cu
🥳 CI Workflow Results🟩 Finished in 1h 25m: Pass: 100%/287 | Total: 2d 21h | Max: 56m 37s | Hits: 85%/228240See results here. |
Description
Closes #7914
Currently tests are limited to a
kthat is uniform across all segments. We now extend tests to support segment-specifickvalues, i.e.,kvalues that can vary across segments.