Fuse pointwise across split slices #4733

Open
pfultz2 wants to merge 4 commits into develop from fuse-pointwise-split-slices

Collaborator

@pfultz2 pfultz2 commented Apr 1, 2026

Motivation

Technical Details

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

@pfultz2 pfultz2 requested a review from causten as a code owner April 1, 2026 22:17
Copilot AI review requested due to automatic review settings April 1, 2026 22:17
Contributor

Copilot AI left a comment

Pull request overview

This PR extends pointwise fusion by splitting a single pointwise op that feeds multiple non-overlapping slice consumers into multiple per-slice pointwise ops, enabling downstream pointwise fusion within each sliced path.

Changes:

  • Add a new split_pointwise_through_slices() pre-pass step in fuse_pointwise to push pointwise computation through split slices.
  • Update the fuse_pointwise fixpoint loop to account for the new splitting transform.
  • Add multiple unit tests covering 2-way/3-way slice splits and “do not split” scenarios (overlap, not all slices, no downstream pointwise), plus a larger 3-input add pattern.
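
The rewrite described above is only valid because slicing commutes with an elementwise operation: slicing the result of a pointwise op gives the same values as applying that op to the sliced input. A minimal sketch of that property on flat vectors follows; `pointwise` and `slice` here are hypothetical stand-ins written for illustration, not MIGraphX operators.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a pointwise op: applies f(v) = 2v + 1 to every element.
std::vector<int> pointwise(const std::vector<int>& x)
{
    std::vector<int> y;
    y.reserve(x.size());
    for(int v : x)
        y.push_back(2 * v + 1);
    return y;
}

// Hypothetical stand-in for a single-axis slice over the half-open range [start, end).
std::vector<int> slice(const std::vector<int>& x, std::size_t start, std::size_t end)
{
    assert(start <= end && end <= x.size());
    auto first = x.begin() + static_cast<std::ptrdiff_t>(start);
    auto last  = x.begin() + static_cast<std::ptrdiff_t>(end);
    return std::vector<int>(first, last);
}
```

For any valid range, `slice(pointwise(x), a, b)` and `pointwise(slice(x, a, b))` hold the same elements, so each slice consumer can be given its own smaller pointwise op.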

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

File | Description
src/fuse_pointwise.cpp | Introduces slice-driven pointwise splitting and integrates it into the fusion loop.
test/fuse_pointwise.cpp | Adds tests validating the new split-through-slices behavior and non-applicable cases.

Comment on lines +276 to +279
bool changed = false;
auto& m = mpm.get_module();
std::size_t idx = 0;

Copilot AI Apr 1, 2026

idx is reset to 0 on every call, but split_pointwise_through_slices() is invoked inside the 0..7 fixpoint loop in fuse_pointwise::apply(). If the pass ends up splitting the same source pointwise module name in a later iteration (e.g., after additional pointwise merges), create_module(pm_name + ":split" + std::to_string(idx++), ...) can attempt to recreate an already-existing module name and hit program::create_module's assert(not contains(modules, name)). Use a counter that is unique across invocations/iterations (or derive uniqueness from the slice/pointwise instruction) so module names cannot collide.
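
One way to avoid the collision is to derive each name by probing the set of names that already exist, rather than relying on a counter that restarts at 0. The sketch below is an illustration only: `make_unique_split_name` and the `std::set` of names are hypothetical and stand in for whatever registry the program actually keeps.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>

// Hypothetical helper: returns base + ":split" + N for the smallest N >= 0 whose
// name is not yet in `existing`, and records that name so later calls skip it.
std::string make_unique_split_name(std::set<std::string>& existing, const std::string& base)
{
    std::size_t idx = 0;
    std::string name;
    do
    {
        name = base + ":split" + std::to_string(idx++);
    } while(existing.count(name) > 0);
    existing.insert(name);
    return name;
}
```

Because uniqueness comes from the registry and not from call-local state, repeated invocations within the fixpoint loop cannot hand out the same name twice.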

Comment on lines +335 to +343
for(const auto& slice_ins : outputs)
{
    auto slice_op = slice_ins->get_operator();

    std::vector<instruction_ref> sliced_inputs;
    sliced_inputs.reserve(inputs.size());
    transform(inputs, std::back_inserter(sliced_inputs), [&](instruction_ref input) {
        return m.insert_instruction(slice_ins, slice_op, input);
    });

Copilot AI Apr 1, 2026

This transformation assumes each slice consumer can be replicated onto every pointwise input by calling insert_instruction(slice_ins, slice_op, input). That only works if the slice instruction is the 1-input form with axes/starts/ends fully specified as operator attributes; if the original slice uses 2–4 inputs (starts/ends/axes provided as inputs) this will produce an invalid slice and can throw during shape computation. Consider explicitly requiring slice_ins->inputs().size() == 1 (and/or copying all slice inputs) before applying the split.
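
The guard suggested here could look roughly like the following. This is a simplified model, not MIGraphX code: `mock_instruction` is a hypothetical type that keeps only the operator name and input count, which is all the check needs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical, simplified instruction: only the operator name and the number
// of inputs matter for this check.
struct mock_instruction
{
    std::string op_name;
    std::size_t num_inputs;
};

// True only when there is at least one consumer and every consumer is a slice
// in its single-input form, i.e. axes/starts/ends are operator attributes.
bool all_static_slices(const std::vector<mock_instruction>& outputs)
{
    return !outputs.empty() &&
           std::all_of(outputs.begin(), outputs.end(), [](const mock_instruction& ins) {
               return ins.op_name == "slice" && ins.num_inputs == 1;
           });
}
```

Rejecting the multi-input form up front keeps the transform from replicating a slice whose bounds live in other instructions.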

Comment on lines +298 to +313
slice_vals[output] = output->get_operator().to_value();

// All slices must be on the same single axis
auto axes = slice_vals[outputs.front()]["axes"].to_vector<int64_t>();
if(axes.size() != 1)
    continue;
if(not all_of(outputs, [&](instruction_ref output) {
       return slice_vals[output]["axes"].to_vector<int64_t>() == axes;
   }))
    continue;

auto get_starts = [&](instruction_ref s) {
    return slice_vals[s]["starts"].to_vector<int64_t>()[0];
};
auto get_ends = [&](instruction_ref s) {
    return slice_vals[s]["ends"].to_vector<int64_t>()[0];

Copilot AI Apr 1, 2026

slice_vals[output]["axes"] / ["starts"] / ["ends"] uses value::operator[], which inserts a null value when the key is missing; the subsequent to_vector<>() will then throw (null isn't an array/object). This makes the pass brittle if any slice is missing these attributes (or if normalization rewrites them). Prefer reading from slice_ins->normalized_operator().to_value() and using contains() / get(...) (or at(...) with a prior check) to avoid accidental insertion and exceptions.

Suggested change
- slice_vals[output] = output->get_operator().to_value();
- // All slices must be on the same single axis
- auto axes = slice_vals[outputs.front()]["axes"].to_vector<int64_t>();
- if(axes.size() != 1)
-     continue;
- if(not all_of(outputs, [&](instruction_ref output) {
-        return slice_vals[output]["axes"].to_vector<int64_t>() == axes;
-    }))
-     continue;
- auto get_starts = [&](instruction_ref s) {
-     return slice_vals[s]["starts"].to_vector<int64_t>()[0];
- };
- auto get_ends = [&](instruction_ref s) {
-     return slice_vals[s]["ends"].to_vector<int64_t>()[0];
+ slice_vals.emplace(output, output->normalized_operator().to_value());
+ // All slices must be on the same single axis
+ auto& first_slice_val = slice_vals.at(outputs.front());
+ if(not first_slice_val.contains("axes"))
+     continue;
+ auto axes = first_slice_val.at("axes").to_vector<int64_t>();
+ if(axes.size() != 1)
+     continue;
+ // All slices must have matching axes and valid starts/ends attributes
+ bool valid_slices = all_of(outputs, [&](instruction_ref output) {
+     auto& sv = slice_vals.at(output);
+     if(not sv.contains("axes") or not sv.contains("starts") or not sv.contains("ends"))
+         return false;
+     if(sv.at("axes").to_vector<int64_t>() != axes)
+         return false;
+     auto starts_vec = sv.at("starts").to_vector<int64_t>();
+     auto ends_vec = sv.at("ends").to_vector<int64_t>();
+     return (not starts_vec.empty()) and (not ends_vec.empty());
+ });
+ if(not valid_slices)
+     continue;
+ auto get_starts = [&](instruction_ref s) {
+     auto& sv = slice_vals.at(s);
+     auto starts_vec = sv.at("starts").to_vector<int64_t>();
+     return starts_vec[0];
+ };
+ auto get_ends = [&](instruction_ref s) {
+     auto& sv = slice_vals.at(s);
+     auto ends_vec = sv.at("ends").to_vector<int64_t>();
+     return ends_vec[0];
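
The hazard this comment describes, a subscript operator that inserts an entry when the key is missing, has a well-known analogue in the standard library. The sketch below uses `std::map` purely as an analogy for the attribute container; `attributes` and `find_attr` are hypothetical names, and nothing here is MIGraphX API.

```cpp
#include <cassert>
#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Hypothetical attribute store used only for this analogy.
using attributes = std::map<std::string, std::vector<std::int64_t>>;

// Non-inserting lookup: returns a pointer to the stored vector, or nullptr when
// the key is absent. Taking the map by const reference rules out accidental insertion.
const std::vector<std::int64_t>* find_attr(const attributes& attrs, const std::string& key)
{
    auto it = attrs.find(key);
    return it == attrs.end() ? nullptr : &it->second;
}
```

With `std::map`, writing `attrs["ends"]` on a map that lacks the key grows the map by one empty entry, while `find_attr(attrs, "ends")` leaves it unchanged and reports the absence, which is the behavior the check-then-read pattern in the suggestion aims for.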

Comment on lines +309 to +322
auto get_starts = [&](instruction_ref s) {
    return slice_vals[s]["starts"].to_vector<int64_t>()[0];
};
auto get_ends = [&](instruction_ref s) {
    return slice_vals[s]["ends"].to_vector<int64_t>()[0];
};

// Sort slices by start position and check for no overlap
std::sort(outputs.begin(), outputs.end(), by(std::less<>{}, get_starts));
if(std::adjacent_find(
       outputs.begin(), outputs.end(), [&](instruction_ref a, instruction_ref b) {
           return get_starts(b) < get_ends(a);
       }) != outputs.end())
    continue;

Copilot AI Apr 1, 2026

The overlap check/sort uses raw starts/ends from get_operator().to_value(). For slice, these indices can be subject to normalization rules (e.g., negative indices or large sentinel ends values clipped to the axis length). Using unnormalized values can mis-detect overlap/order and either skip valid opportunities or apply the split when slices actually overlap after normalization. Use the slice instruction’s normalized operator values (or otherwise normalize indices against the input shape) for starts/ends comparisons.
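
To make the concern concrete, here is a standalone sketch of an overlap check that normalizes indices before sorting. It assumes the common convention that a negative index counts from the end of the axis and that the result is then clamped to `[0, len]`; whether that matches the library's normalization exactly is an assumption, and `normalize_index`, `range`, and `ranges_overlap` are hypothetical names.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Maps a possibly negative or oversized index into [0, len].
// Assumption: negative values count from the end, then the result is clamped.
std::int64_t normalize_index(std::int64_t idx, std::int64_t len)
{
    assert(len >= 0);
    if(idx < 0)
        idx += len;
    return std::max<std::int64_t>(0, std::min(idx, len));
}

using range = std::pair<std::int64_t, std::int64_t>; // half-open [start, end)

// Normalizes every range against the axis length, sorts by start, and reports
// whether any range begins before the previous one ends.
bool ranges_overlap(std::vector<range> ranges, std::int64_t len)
{
    for(auto& r : ranges)
        r = {normalize_index(r.first, len), normalize_index(r.second, len)};
    std::sort(ranges.begin(), ranges.end());
    return std::adjacent_find(ranges.begin(), ranges.end(), [](const range& a, const range& b) {
               return b.first < a.second;
           }) != ranges.end();
}
```

On an axis of length 10, the ranges `[0, 5)` and `[-5, 10)` are disjoint once `-5` becomes `5`, yet comparing the raw values would flag them as overlapping; conversely `[0, 5)` and `[-7, -1)` look disjoint as raw values but overlap after normalization to `[3, 9)`.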

@codecov

codecov bot commented Apr 1, 2026

Codecov Report

❌ Patch coverage is 96.22642% with 2 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/fuse_pointwise.cpp | 96.23% | 2 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #4733      +/-   ##
===========================================
+ Coverage    92.29%   92.29%   +0.01%     
===========================================
  Files          580      580              
  Lines        28688    28740      +52     
===========================================
+ Hits         26475    26525      +50     
- Misses        2213     2215       +2     
Files with missing lines | Coverage Δ
src/fuse_pointwise.cpp | 97.58% <96.23%> (-0.38%) ⬇️
