@@ -77,6 +77,80 @@ def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher):
7777 # src[test_rng.py:N]: return rand1, rand2, uniform, normal, randn_sum
7878 return (rand1, rand2, uniform, normal, randn_sum)
7979
80+ --- assertExpectedJournal(TestRNG.test_rand_like_nested_tiles_issue_1208)
81+ from __future__ import annotations
82+
83+ import torch
84+ import triton
85+ import triton.language as tl
86+ from helion.runtime import default_launcher as _default_launcher
87+
@triton.jit
def _helion_nested_tiles_rand(q, out, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _RDIM_SIZE_2: tl.constexpr, _BLOCK_SIZE_3: tl.constexpr, rng_seed_buffer):
    # Device side of `nested_tiles_rand`. Each program works on one
    # (tile_b, tile_q) tile of `q`, walks tile_k in steps of _BLOCK_SIZE_3,
    # and writes a tile of `out`.
    #
    # Extents (2, 16) and element strides (512, 32, 1) are literal constants
    # in this kernel; presumably q/out are contiguous (2, 16, 32) tensors --
    # confirm against the caller before reusing with other shapes.

    # for tile_b, tile_q in hl.tile([B, T]):
    # The 1-D program id is split into a block index per tiled dimension.
    program = tl.program_id(0)
    blocks_b = tl.cdiv(2, _BLOCK_SIZE_0)
    pid_b = program % blocks_b
    pid_q = program // blocks_b
    idx_b = (pid_b * _BLOCK_SIZE_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
    idx_q = (pid_q * _BLOCK_SIZE_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
    idx_h = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)

    # Element offsets shared by every access to the [tile_b, tile_q, :] tile.
    b_off = idx_b[:, None, None] * 512
    h_off = idx_h[None, None, :] * 1
    tile_off = b_off + idx_q[None, :, None] * 32 + h_off

    # qs = q[tile_b, tile_q, :]
    qs = tl.load(q + tile_off, None)

    # for tile_k in hl.tile(T):
    for k_start in tl.range(0, 16, _BLOCK_SIZE_3):
        idx_k = k_start + tl.arange(0, _BLOCK_SIZE_3).to(tl.int32)

        # ks = q[tile_b, tile_k, :]
        ks = tl.load(q + (b_off + idx_k[None, :, None] * 32 + h_off), None)

        # logits = qs @ ks.transpose(-1, -2)   -> [tile_b, tile_q, tile_k]
        ks_t = tl.permute(ks, [0, 2, 1])
        logits = tl.dot(tl.cast(qs, tl.float32), tl.cast(ks_t, tl.float32), input_precision='tf32', out_dtype=tl.float32)

        # rand = torch.rand_like(logits)
        # The RNG offset combines the b, q and k indices, so it differs for
        # every (b, q, k) position of the logits tile.
        seed = tl.load(rng_seed_buffer + 0)
        rand_off = idx_b[:, None, None] * 16 * 16 + idx_q[None, :, None] * 16 + idx_k[None, None, :]
        rand = tl.rand(seed, rand_off).to(tl.float32)

        # mask = ((logits + rand) > 0).float()
        mask = tl.cast((logits + rand) > 0.0, tl.float32)

        # out[tile_b, tile_q, :] = torch.matmul(mask, q[tile_b, tile_q, :])
        q_tile = tl.load(q + tile_off, None)
        product = tl.dot(tl.cast(mask, tl.float32), tl.cast(q_tile, tl.float32), input_precision='tf32', out_dtype=tl.float32)
        tl.store(out + tile_off, product, None)
125+
def nested_tiles_rand(q: torch.Tensor, *, _launcher=_default_launcher):
    """Host-side launcher for the `_helion_nested_tiles_rand` kernel.

    Allocates one RNG seed, creates an uninitialised output with the same
    shape, device and dtype as ``q``, launches the kernel, and returns the
    output.

    Args:
        q: Input tensor; must unpack into three dimensions. The launch grid
            is built from the literals 2 and 16 rather than from ``q.shape``
            -- presumably ``q`` is expected to be (2, 16, 32); confirm with
            the originating test.
        _launcher: Callable invoked as
            ``_launcher(kernel, grid, *args, num_warps=..., num_stages=...)``.

    Returns:
        The freshly allocated output tensor passed to the kernel as ``out``.
    """
    from torch._inductor import inductor_prims

    # One seed slot; the kernel reads it at index 0.
    seed_buffer = inductor_prims.seeds(1, torch.accelerator.current_accelerator())

    # B, T, H = q.shape
    B, T, H = q.shape
    # out = torch.empty((B, T, H), device=q.device, dtype=q.dtype)
    out = torch.empty((B, T, H), device=q.device, dtype=q.dtype)

    # Tile sizes for `hl.tile([B, T])`, the trailing dimension, and `hl.tile(T)`.
    block_b = 2
    block_q = 16
    rdim_h = 32
    block_k = 16

    # One program per (tile_b, tile_q) pair, flattened into a 1-D grid.
    program_count = triton.cdiv(2, block_b) * triton.cdiv(16, block_q)
    _launcher(
        _helion_nested_tiles_rand,
        (program_count,),
        q,
        out,
        block_b,
        block_q,
        rdim_h,
        block_k,
        seed_buffer,
        num_warps=4,
        num_stages=1,
    )
    # return out
    return out
153+
80154--- assertExpectedJournal(TestRNG.test_rand_like_with_specialized_dimension)
81155from __future__ import annotations
82156
@@ -150,4 +224,3 @@ def matmul_with_rand(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_lau
150224 _launcher(_helion_matmul_with_rand, (triton.cdiv(256, _BLOCK_SIZE_0),), x, y, out, _BLOCK_SIZE_0, _RDIM_SIZE_2, _BLOCK_SIZE_1, _rng_seed_buffer, num_warps=4, num_stages=1)
151225 # src[test_rng.py:N]: return out
152226 return out
153-
0 commit comments