Commit 01bbfbd

Fix bug with torch.rand_like compile error (#1289)
Fixes #1208
1 parent d2f7a01 commit 01bbfbd

3 files changed: +130, -2 lines

helion/_compiler/aten_lowering.py

Lines changed: 1 addition & 1 deletion
@@ -561,7 +561,7 @@ def _codegen_rng_op(
     for i in range(ndim):
         # Create the index variable with proper broadcasting
         if block_ids[i] is not None:
-            index_expr = f"indices_{i}"
+            index_expr = f"indices_{block_ids[i]}"
         else:
             # For constant dimensions (block_id is None), use tl.arange directly
             index_expr = f"tl.arange(0, {dim_names[i]})"

test/test_rng.expected

Lines changed: 74 additions & 1 deletion
@@ -77,6 +77,80 @@ def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
     # src[test_rng.py:N]: return rand1, rand2, uniform, normal, randn_sum
     return (rand1, rand2, uniform, normal, randn_sum)
 
+--- assertExpectedJournal(TestRNG.test_rand_like_nested_tiles_issue_1208)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_nested_tiles_rand(q, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, rng_seed_buffer):
+    # src[test_rng.py:N]: for tile_b, tile_q in hl.tile([B, T]):
+    num_blocks_0 = tl.cdiv(2, _BLOCK_SIZE_0)
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    offset_1 = pid_1 * _BLOCK_SIZE_1
+    indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
+    indices_2 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
+    # src[test_rng.py:N]: qs = q[tile_b, tile_q, :]
+    qs = tl.load(q + (indices_0[:, None, None] * 512 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), None)
+    # src[test_rng.py:N]: for tile_k in hl.tile(T):
+    # src[test_rng.py:N]: ks = q[tile_b, tile_k, :]
+    # src[test_rng.py:N]: # logits has shape [tile_b, tile_q, tile_k]
+    # src[test_rng.py:N-N]: ...
+    for offset_3 in tl.range(0, 16, _BLOCK_SIZE_3):
+        indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)
+        qs_copy = qs
+        qs_copy_0 = qs_copy
+        # src[test_rng.py:N]: ks = q[tile_b, tile_k, :]
+        ks = tl.load(q + (indices_0[:, None, None] * 512 + indices_3[None, :, None] * 32 + indices_2[None, None, :] * 1), None)
+        # src[test_rng.py:N]: logits = qs @ ks.transpose(-1, -2)
+        permute = tl.permute(ks, [0, 2, 1])
+        logits = tl.dot(tl.cast(qs_copy_0, tl.float32), tl.cast(permute, tl.float32), input_precision='tf32', out_dtype=tl.float32)
+        # src[test_rng.py:N]: rand = torch.rand_like(logits)
+        rand = tl.rand(tl.load(rng_seed_buffer + 0), indices_0[:, None, None] * 16 * 16 + indices_1[None, :, None] * 16 + indices_3[None, None, :]).to(tl.float32)
+        # src[test_rng.py:N]: mask = ((logits + rand) > 0).float()
+        v_0 = logits + rand
+        v_1 = 0.0
+        v_2 = v_0 > v_1
+        v_3 = tl.cast(v_2, tl.float32)
+        # src[test_rng.py:N]: out[tile_b, tile_q, :] = torch.matmul(mask, q[tile_b, tile_q, :])
+        load_1 = tl.load(q + (indices_0[:, None, None] * 512 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), None)
+        bmm_1 = tl.dot(tl.cast(v_3, tl.float32), tl.cast(load_1, tl.float32), input_precision='tf32', out_dtype=tl.float32)
+        tl.store(out + (indices_0[:, None, None] * 512 + indices_1[None, :, None] * 32 + indices_2[None, None, :] * 1), bmm_1, None)
+
+def nested_tiles_rand(q: torch.Tensor, *, _launcher=_default_launcher):
+    from torch._inductor import inductor_prims
+    # src[test_rng.py:N]: def nested_tiles_rand(q: torch.Tensor) -> torch.Tensor:
+    # src[test_rng.py:N]: B, T, H = q.shape
+    # src[test_rng.py:N]: out = torch.empty((B, T, H), device=q.device, dtype=q.dtype)
+    # src[test_rng.py:N-N]: ...
+    _rng_seed_buffer = inductor_prims.seeds(1, torch.accelerator.current_accelerator())
+    # src[test_rng.py:N]: B, T, H = q.shape
+    B, T, H = q.shape
+    # src[test_rng.py:N]: out = torch.empty((B, T, H), device=q.device, dtype=q.dtype)
+    out = torch.empty((B, T, H), device=q.device, dtype=q.dtype)
+    # src[test_rng.py:N]: for tile_b, tile_q in hl.tile([B, T]):
+    _BLOCK_SIZE_0 = 2
+    _BLOCK_SIZE_1 = 16
+    _RDIM_SIZE_2 = 32
+    # src[test_rng.py:N]: for tile_k in hl.tile(T):
+    # src[test_rng.py:N]: ks = q[tile_b, tile_k, :]
+    # src[test_rng.py:N]: # logits has shape [tile_b, tile_q, tile_k]
+    # src[test_rng.py:N-N]: ...
+    _BLOCK_SIZE_3 = 16
+    # src[test_rng.py:N]: for tile_b, tile_q in hl.tile([B, T]):
+    # src[test_rng.py:N]: qs = q[tile_b, tile_q, :]
+    # src[test_rng.py:N]: for tile_k in hl.tile(T):
+    # src[test_rng.py:N-N]: ...
+    _launcher(_helion_nested_tiles_rand, (triton.cdiv(2, _BLOCK_SIZE_0) * triton.cdiv(16, _BLOCK_SIZE_1),), q, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _RDIM_SIZE_2, _BLOCK_SIZE_3, _rng_seed_buffer, num_warps=4, num_stages=1)
+    # src[test_rng.py:N]: return out
+    return out
+
 --- assertExpectedJournal(TestRNG.test_rand_like_with_specialized_dimension)
 from __future__ import annotations
 

@@ -150,4 +224,3 @@ def matmul_with_rand(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_matmul_with_rand, (triton.cdiv(256, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=1)
     # src[test_rng.py:N]: return out
     return out
-
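
The line worth studying in the new expected output is the tl.rand call: rather than re-seeding per tile, the kernel linearizes (tile_b, tile_q, tile_k) into one per-element offset (indices_0 * 16 * 16 + indices_1 * 16 + indices_3), so every element of logits draws from its own counter no matter how the loops are tiled. Here is a minimal standalone sketch of that counter-based pattern, not part of this commit; the kernel name, wrapper, and CUDA device are assumptions.

import torch
import triton
import triton.language as tl

@triton.jit
def _rand_fill(out, n, seed, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)  # unique per-element counters
    vals = tl.rand(seed, offs)  # same offset -> same value, for any BLOCK
    tl.store(out + offs, vals, mask=offs < n)

def rand_fill(n: int, seed: int, block: int = 128) -> torch.Tensor:
    out = torch.empty(n, device="cuda", dtype=torch.float32)
    _rand_fill[(triton.cdiv(n, block),)](out, n, seed, BLOCK=block)
    return out

Because tl.rand is counter-based, changing block re-tiles the launch without changing any value; that tiling-independence is exactly what breaks if the offset expression is built from the wrong index variables.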

test/test_rng.py

Lines changed: 55 additions & 0 deletions
@@ -506,6 +506,61 @@ def matmul_with_rand(
         # Verify generated code
         self.assertExpectedJournal(code)
 
+    def test_rand_like_nested_tiles_issue_1208(self):
+        """Test torch.rand_like with nested tiles (regression test for issue #1208).
+
+        This test reproduces the bug where torch.rand_like() failed with nested tiles
+        because the RNG codegen incorrectly used dimension indices instead of block_ids
+        when constructing index variable names.
+        """
+
+        @helion.kernel(
+            autotune_effort="none",
+            static_shapes=True,
+            ignore_warnings=[helion.exc.TensorOperationInWrapper],
+        )
+        def nested_tiles_rand(q: torch.Tensor) -> torch.Tensor:
+            B, T, H = q.shape
+            out = torch.empty((B, T, H), device=q.device, dtype=q.dtype)
+
+            for tile_b, tile_q in hl.tile([B, T]):
+                qs = q[tile_b, tile_q, :]
+                for tile_k in hl.tile(T):
+                    ks = q[tile_b, tile_k, :]
+                    # logits has shape [tile_b, tile_q, tile_k]
+                    # The third dimension uses indices_3 (from the inner loop)
+                    # not indices_2 (from H dimension)
+                    logits = qs @ ks.transpose(-1, -2)
+
+                    # This used to fail because rand_like incorrectly used
+                    # indices_2 (size H=32) instead of indices_3 (size tile_k=16)
+                    rand = torch.rand_like(logits)
+
+                    mask = ((logits + rand) > 0).float()
+                    out[tile_b, tile_q, :] = torch.matmul(mask, q[tile_b, tile_q, :])
+
+            return out
+
+        q = torch.randn(2, 16, 32, device=DEVICE, dtype=torch.float32)
+        torch.manual_seed(42)
+        code, result = code_and_output(nested_tiles_rand, (q,))
+
+        # Verify output shape
+        self.assertEqual(result.shape, (2, 16, 32))
+
+        # Verify reproducibility
+        torch.manual_seed(42)
+        _code2, result2 = code_and_output(nested_tiles_rand, (q,))
+        torch.testing.assert_close(result, result2)
+
+        # Verify different seeds produce different results
+        torch.manual_seed(123)
+        _code3, result3 = code_and_output(nested_tiles_rand, (q,))
+        self.assertFalse(torch.allclose(result, result3))
+
+        # Verify generated code
+        self.assertExpectedJournal(code)
+
 
 if __name__ == "__main__":
     unittest.main()
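
A quick eager-mode check that the offset formula in the expected journal assigns each logits element its own counter, assuming the test's shapes (B=2, T=16) and the b*T*T + q*T + k linearization read off the generated tl.rand call:

import torch

B, T = 2, 16  # test shapes; logits is [B, T, T]
b = torch.arange(B)[:, None, None]
q = torch.arange(T)[None, :, None]
k = torch.arange(T)[None, None, :]
# Mirrors indices_0 * 16 * 16 + indices_1 * 16 + indices_3 in the kernel
offsets = b * T * T + q * T + k
# All 512 offsets are distinct, so no two elements share a counter,
# and torch.manual_seed alone determines the output.
assert offsets.unique().numel() == B * T * T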
