Pull request overview
This PR extends pointwise fusion by splitting a single pointwise op that feeds multiple non-overlapping slice consumers into multiple per-slice pointwise ops, enabling downstream pointwise fusion within each sliced path.
Changes:
- Add a new `split_pointwise_through_slices()` pre-pass step in `fuse_pointwise` to push pointwise computation through split slices.
- Update the `fuse_pointwise` fixpoint loop to account for the new splitting transform.
- Add multiple unit tests covering 2-way/3-way slice splits and "do not split" scenarios (overlap, not all slices, no downstream pointwise), plus a larger 3-input add pattern.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| src/fuse_pointwise.cpp | Introduces slice-driven pointwise splitting and integrates it into the fusion loop. |
| test/fuse_pointwise.cpp | Adds tests validating the new split-through-slices behavior and non-applicable cases. |

    bool changed = false;
    auto& m = mpm.get_module();
    std::size_t idx = 0;

`idx` is reset to 0 on every call, but `split_pointwise_through_slices()` is invoked inside the 0..7 fixpoint loop in `fuse_pointwise::apply()`. If a later iteration splits a pointwise module with the same source name (e.g., after additional pointwise merges), `create_module(pm_name + ":split" + std::to_string(idx++), ...)` can try to recreate an existing module name and hit the `assert(not contains(modules, name))` in `program::create_module`. Use a counter that is unique across invocations and iterations (or derive uniqueness from the slice/pointwise instruction) so module names cannot collide.
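A minimal sketch of one way to keep generated names unique across passes, under stated assumptions: the helper `make_unique_split_name` and the `existing` set are hypothetical stand-ins for however the program tracks its module names, and nothing here is MIGraphX API. The idea is to probe for a free suffix instead of trusting a counter that restarts at 0.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_set>

// Hypothetical helper (not part of MIGraphX): returns base + ":split" + N for
// the smallest N whose name is not yet in `existing`, and records that name.
std::string make_unique_split_name(std::unordered_set<std::string>& existing,
                                   const std::string& base)
{
    std::size_t idx = 0;
    std::string name;
    do
    {
        name = base + ":split" + std::to_string(idx++);
    } while(existing.count(name) > 0);
    existing.insert(name);
    return name;
}
```

Because the lookup runs against the set of names already taken, a second pass over the same source module continues at the next free suffix even though its local counter starts over.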

    for(const auto& slice_ins : outputs)
    {
        auto slice_op = slice_ins->get_operator();

        std::vector<instruction_ref> sliced_inputs;
        sliced_inputs.reserve(inputs.size());
        transform(inputs, std::back_inserter(sliced_inputs), [&](instruction_ref input) {
            return m.insert_instruction(slice_ins, slice_op, input);
        });

This transformation assumes each slice consumer can be replicated onto every pointwise input by calling `insert_instruction(slice_ins, slice_op, input)`. That only works if the slice instruction is the 1-input form with `axes`/`starts`/`ends` fully specified as operator attributes. If the original slice uses 2–4 inputs (with `starts`/`ends`/`axes` provided as inputs), this will produce an invalid slice and can throw during shape computation. Consider explicitly requiring `slice_ins->inputs().size() == 1` (and/or copying all slice inputs) before applying the split.

    slice_vals[output] = output->get_operator().to_value();

    // All slices must be on the same single axis
    auto axes = slice_vals[outputs.front()]["axes"].to_vector<int64_t>();
    if(axes.size() != 1)
        continue;
    if(not all_of(outputs, [&](instruction_ref output) {
           return slice_vals[output]["axes"].to_vector<int64_t>() == axes;
       }))
        continue;

    auto get_starts = [&](instruction_ref s) {
        return slice_vals[s]["starts"].to_vector<int64_t>()[0];
    };
    auto get_ends = [&](instruction_ref s) {
        return slice_vals[s]["ends"].to_vector<int64_t>()[0];

`slice_vals[output]["axes"]` / `["starts"]` / `["ends"]` uses `value::operator[]`, which inserts a null value when the key is missing; the subsequent `to_vector<>()` will then throw (null isn't an array/object). This makes the pass brittle if any slice is missing these attributes (or if normalization rewrites them). Prefer reading from `slice_ins->normalized_operator().to_value()` and using `contains()` / `get(...)` (or `at(...)` with a prior check) to avoid accidental insertion and exceptions.
Suggested change:

    - slice_vals[output] = output->get_operator().to_value();
    - // All slices must be on the same single axis
    - auto axes = slice_vals[outputs.front()]["axes"].to_vector<int64_t>();
    - if(axes.size() != 1)
    -     continue;
    - if(not all_of(outputs, [&](instruction_ref output) {
    -        return slice_vals[output]["axes"].to_vector<int64_t>() == axes;
    -    }))
    -     continue;
    - auto get_starts = [&](instruction_ref s) {
    -     return slice_vals[s]["starts"].to_vector<int64_t>()[0];
    - };
    - auto get_ends = [&](instruction_ref s) {
    -     return slice_vals[s]["ends"].to_vector<int64_t>()[0];
    + slice_vals.emplace(output, output->normalized_operator().to_value());
    + // All slices must be on the same single axis
    + auto& first_slice_val = slice_vals.at(outputs.front());
    + if(not first_slice_val.contains("axes"))
    +     continue;
    + auto axes = first_slice_val.at("axes").to_vector<int64_t>();
    + if(axes.size() != 1)
    +     continue;
    + // All slices must have matching axes and valid starts/ends attributes
    + bool valid_slices = all_of(outputs, [&](instruction_ref output) {
    +     auto& sv = slice_vals.at(output);
    +     if(not sv.contains("axes") or not sv.contains("starts") or not sv.contains("ends"))
    +         return false;
    +     if(sv.at("axes").to_vector<int64_t>() != axes)
    +         return false;
    +     auto starts_vec = sv.at("starts").to_vector<int64_t>();
    +     auto ends_vec = sv.at("ends").to_vector<int64_t>();
    +     return (not starts_vec.empty()) and (not ends_vec.empty());
    + });
    + if(not valid_slices)
    +     continue;
    + auto get_starts = [&](instruction_ref s) {
    +     auto& sv = slice_vals.at(s);
    +     auto starts_vec = sv.at("starts").to_vector<int64_t>();
    +     return starts_vec[0];
    + };
    + auto get_ends = [&](instruction_ref s) {
    +     auto& sv = slice_vals.at(s);
    +     auto ends_vec = sv.at("ends").to_vector<int64_t>();
    +     return ends_vec[0];

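The insert-on-lookup pitfall described in this comment has a familiar analogue in the standard library, which may make the concern easier to see. The sketch below uses `std::map` as a stand-in; it does not use `migraphx::value`, and the names `attrs`, `size_after_bracket_lookup`, and `try_first` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Stand-in for an attribute container: key -> list of integers.
using attrs = std::map<std::string, std::vector<std::int64_t>>;

// Takes the map by value so the caller's copy is untouched. operator[] on a
// missing key default-constructs an empty vector and inserts it.
std::size_t size_after_bracket_lookup(attrs a, const std::string& key)
{
    a[key];
    return a.size();
}

// Checked lookup: never inserts, and reports failure instead of indexing
// into a missing or empty entry.
bool try_first(const attrs& a, const std::string& key, std::int64_t& out)
{
    auto it = a.find(key);
    if(it == a.end() or it->second.empty())
        return false;
    out = it->second.front();
    return true;
}
```

With a checked lookup, a slice that lacks `ends` becomes a "skip this candidate" result rather than an exception or a silently mutated container.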

    auto get_starts = [&](instruction_ref s) {
        return slice_vals[s]["starts"].to_vector<int64_t>()[0];
    };
    auto get_ends = [&](instruction_ref s) {
        return slice_vals[s]["ends"].to_vector<int64_t>()[0];
    };

    // Sort slices by start position and check for no overlap
    std::sort(outputs.begin(), outputs.end(), by(std::less<>{}, get_starts));
    if(std::adjacent_find(
           outputs.begin(), outputs.end(), [&](instruction_ref a, instruction_ref b) {
               return get_starts(b) < get_ends(a);
           }) != outputs.end())
        continue;

The overlap check and sort use raw `starts`/`ends` from `get_operator().to_value()`. For `slice`, these indices can be subject to normalization rules (e.g., negative indices, or large sentinel `ends` values clipped to the axis length). Using unnormalized values can mis-detect overlap or order, and either skip valid opportunities or apply the split when slices actually overlap after normalization. Use the slice instruction's normalized operator values (or otherwise normalize indices against the input shape) for the `starts`/`ends` comparisons.
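A small self-contained sketch of why normalization matters for the overlap test. The normalization rule here (negative indices count from the end, then clamp to `[0, len]`) is an assumption for illustration and may differ from what MIGraphX applies; `normalize_index`, `range`, and `has_overlap` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Assumed rule: a negative index counts from the end of the axis, and the
// result is clamped to [0, len].
std::int64_t normalize_index(std::int64_t i, std::int64_t len)
{
    if(i < 0)
        i += len;
    return std::max<std::int64_t>(0, std::min(i, len));
}

using range = std::pair<std::int64_t, std::int64_t>; // half-open [start, end)

// Normalizes every range against the axis length, sorts by start, and
// reports whether any neighbouring pair overlaps.
bool has_overlap(std::vector<range> slices, std::int64_t len)
{
    for(auto& s : slices)
        s = {normalize_index(s.first, len), normalize_index(s.second, len)};
    std::sort(slices.begin(), slices.end());
    return std::adjacent_find(slices.begin(), slices.end(), [](const range& a, const range& b) {
               return b.first < a.second;
           }) != slices.end();
}
```

On an axis of length 10, the raw pair `{0, -2}` and `{5, 10}` passes a naive `start(b) < end(a)` test because `5 < -2` is false, yet after normalization the ranges are `[0, 8)` and `[5, 10)`, which overlap. Conversely, `{-5, 10}` and `{0, 5}` look overlapping when compared raw but normalize to the disjoint ranges `[5, 10)` and `[0, 5)`.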
Codecov Report
❌ Patch coverage is
Additional details and impacted files

Coverage Diff

| | develop | #4733 | +/- |
|---|---|---|---|
| Coverage | 92.29% | 92.29% | +0.01% |
| Files | 580 | 580 | |
| Lines | 28688 | 28740 | +52 |
| Hits | 26475 | 26525 | +50 |
| Misses | 2213 | 2215 | +2 |