Skip to content

Commit 663bd51

Browse files
[lang] Add tcgen05 shared memory descriptor encoder
Signed-off-by: Asher Mancinelli <amancinelli@nvidia.com>
1 parent dc5a949 commit 663bd51

4 files changed

Lines changed: 115 additions & 53 deletions

File tree

experimental/cuda-lang/src/cuda/lang/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
Tcgen05InstructionDescriptor,
9494
Tcgen05Mxf8f6f4InstructionDescriptor,
9595
Tcgen05Mxf4InstructionDescriptor,
96+
Tcgen05SharedMemoryDescriptor,
9697
tcgen05_alloc,
9798
tcgen05_dealloc,
9899
tcgen05_commit,
@@ -221,6 +222,7 @@
221222
"Tcgen05InstructionDescriptor",
222223
"Tcgen05Mxf8f6f4InstructionDescriptor",
223224
"Tcgen05Mxf4InstructionDescriptor",
225+
"Tcgen05SharedMemoryDescriptor",
224226
"tcgen05_alloc",
225227
"tcgen05_dealloc",
226228
"tcgen05_commit",

experimental/cuda-lang/src/cuda/lang/_stub/bits.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,25 @@
55
from ._nvvm_support import IX, B
66

77

8-
def get_bits(value: IX, position: IX, width: IX) -> IX:
9-
mask = (1 << width) - 1
10-
return (value >> position) & mask
11-
12-
13-
def set_bits(value: IX, field: IX, position: IX, width: IX) -> IX:
8+
def set_bits_fixed(value: IX, field: IX, position: IX, width: IX, full_mask: IX) -> IX:
    """Return ``value`` with a ``width``-bit field at ``position`` replaced by ``field``.

    Args:
        value: Word whose bit field is overwritten.
        field: New field contents; only its low ``width`` bits are used.
        position: Index of the field's least significant bit.
        width: Number of bits in the field.
        full_mask: All-ones mask covering the whole word
            (e.g. ``0xFFFF_FFFF`` for a 32-bit word).

    Returns:
        ``value`` with the selected bits cleared and then set from ``field``.
    """
    field_mask = (1 << width) - 1
    mask = field_mask << position
    # The inverse of ``mask`` is formed by subtracting from ``full_mask``
    # rather than with ``~``; the two agree as long as the field lies
    # entirely inside ``full_mask``.
    clear_mask = full_mask - mask
    insert = (field & field_mask) << position
    return (value & clear_mask) | insert
14+
15+
16+
def set_bits32(value: IX, field: IX, position: IX, width: IX) -> IX:
    """Overwrite a ``width``-bit field at ``position`` in a 32-bit word.

    Delegates to ``set_bits_fixed`` with a 32-bit all-ones mask.
    """
    return set_bits_fixed(value, field, position, width, 0xFFFF_FFFF)
18+
19+
20+
def set_bits64(value: IX, field: IX, position: IX, width: IX) -> IX:
    """Overwrite a ``width``-bit field at ``position`` in a 64-bit word.

    Delegates to ``set_bits_fixed`` with a 64-bit all-ones mask.
    """
    return set_bits_fixed(value, field, position, width, 0xFFFF_FFFF_FFFF_FFFF)
1722

1823

19-
def get_bit(value: IX, position: IX) -> B:
20-
return get_bits(value, position, 1)
24+
def set_bit32(value: IX, position: IX, bit: B = 1) -> IX:
    """Write ``bit`` (default 1) to the single bit at ``position`` of a 32-bit word."""
    return set_bits32(value, bit, position, 1)
2126

2227

23-
def set_bit(value: IX, position: IX, bit: B = 1) -> IX:
24-
return set_bits(value, bit, position, 1)
28+
def set_bit64(value: IX, position: IX, bit: B = 1) -> IX:
    """Write ``bit`` (default 1) to the single bit at ``position`` of a 64-bit word."""
    return set_bits64(value, bit, position, 1)

experimental/cuda-lang/src/cuda/lang/_stub/tcgen05.py

Lines changed: 76 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from enum import Enum, IntEnum
77
from typing import Any, Literal
88

9+
from .._datatype import uint32, uint64
910
from cuda.lang._execution import stub
10-
from .bits import set_bit, set_bits
11+
from .bits import set_bit32, set_bit64, set_bits32, set_bits64
1112
from .nvvm import P3, P6
1213

1314

@@ -155,20 +156,20 @@ class Tcgen05InstructionDescriptor:
155156
max_shift: MaxShift = MaxShift.NoShift
156157

157158
def encode(self) -> int:
    """Pack this descriptor's fields into a single 32-bit word.

    Bit layout written below (least significant bit is bit 0):
    0-1 sparsity selector, 2 sparse, 3 saturate, 4-5 D type,
    7-9 A type, 10-12 B type, 13 negate A, 14 negate B,
    15 transpose A, 16 transpose B, 17-22 ``n >> 3``,
    24-28 ``m >> 4``, 30-31 max shift. Bits not listed stay zero.
    """
    desc = uint32(0x0000_0000)
    desc = set_bits32(desc, self.sparsity_selector, 0, 2)
    desc = set_bit32(desc, 2, self.sparse)
    desc = set_bit32(desc, 3, self.saturate)
    desc = set_bits32(desc, self.d_type, 4, 2)
    desc = set_bits32(desc, self.a_type, 7, 3)
    desc = set_bits32(desc, self.b_type, 10, 3)
    desc = set_bit32(desc, 13, self.negate_a)
    desc = set_bit32(desc, 14, self.negate_b)
    desc = set_bit32(desc, 15, self.transpose_a)
    desc = set_bit32(desc, 16, self.transpose_b)
    # N and M are stored scaled down: N in units of 8, M in units of 16.
    desc = set_bits32(desc, self.n >> 3, 17, 6)
    desc = set_bits32(desc, self.m >> 4, 24, 5)
    desc = set_bits32(desc, self.max_shift, 30, 2)
    return desc
173174

174175

@@ -193,19 +194,19 @@ class Tcgen05Mxf8f6f4InstructionDescriptor:
193194
a_scale_id: Literal[0, 1, 2, 3] = 0
194195

195196
def encode(self) -> int:
    """Pack this descriptor's fields into a single 32-bit word.

    Bit layout written below (least significant bit is bit 0):
    2 sparse, 4-5 B scale id, 7-9 A type, 10-12 B type,
    13 negate A, 14 negate B, 15 transpose A, 16 transpose B,
    17-22 ``n >> 3``, 23 scale format, 27-28 ``m >> 7``,
    29-30 A scale id. Bits not listed stay zero.
    """
    desc = uint32(0x0000_0000)
    desc = set_bit32(desc, 2, self.sparse)
    desc = set_bits32(desc, self.b_scale_id, 4, 2)
    desc = set_bits32(desc, self.a_type, 7, 3)
    desc = set_bits32(desc, self.b_type, 10, 3)
    desc = set_bit32(desc, 13, self.negate_a)
    desc = set_bit32(desc, 14, self.negate_b)
    desc = set_bit32(desc, 15, self.transpose_a)
    desc = set_bit32(desc, 16, self.transpose_b)
    # N and M are stored scaled down: N in units of 8, M in units of 128.
    desc = set_bits32(desc, self.n >> 3, 17, 6)
    desc = set_bit32(desc, 23, self.scale_format)
    desc = set_bits32(desc, self.m >> 7, 27, 2)
    desc = set_bits32(desc, self.a_scale_id, 29, 2)
    return desc
210211

211212

@@ -232,20 +233,52 @@ class Tcgen05Mxf4InstructionDescriptor:
232233
k_dimension: KDimension = KDimension.DenseK64OrSparseK128
233234

234235
def encode(self) -> int:
    """Pack this descriptor's fields into a single 32-bit word.

    Bit layout written below (least significant bit is bit 0):
    2 sparse, 4-5 B scale id, 7-9 A type, 10-11 B type,
    13 negate A, 14 negate B, 15 transpose A, 16 transpose B,
    17-22 ``n >> 3``, 23 scale format, 27-28 ``m >> 7``,
    29-30 A scale id, 31 K dimension. Bits not listed stay zero.
    """
    desc = uint32(0x0000_0000)
    desc = set_bit32(desc, 2, self.sparse)
    desc = set_bits32(desc, self.b_scale_id, 4, 2)
    desc = set_bits32(desc, self.a_type, 7, 3)
    # The B type field is written with width 2 here.
    desc = set_bits32(desc, self.b_type, 10, 2)
    desc = set_bit32(desc, 13, self.negate_a)
    desc = set_bit32(desc, 14, self.negate_b)
    desc = set_bit32(desc, 15, self.transpose_a)
    desc = set_bit32(desc, 16, self.transpose_b)
    # N and M are stored scaled down: N in units of 8, M in units of 128.
    desc = set_bits32(desc, self.n >> 3, 17, 6)
    desc = set_bit32(desc, 23, self.scale_format)
    desc = set_bits32(desc, self.m >> 7, 27, 2)
    desc = set_bits32(desc, self.a_scale_id, 29, 2)
    desc = set_bit32(desc, 31, self.k_dimension)
    return desc
251+
252+
253+
@dataclass(frozen=True)
class Tcgen05SharedMemoryDescriptor:
    """Immutable field set for a 64-bit tcgen05 shared memory descriptor.

    ``encode`` packs the fields into one 64-bit word. The nested enums
    name the values accepted by the leading-dimension-mode bit and the
    swizzle-mode field.
    """

    class LeadingDimMode(IntEnum):
        """Values for the single-bit leading dimension mode (bit 52)."""

        ByteOffsetRelative = 0
        ByteAddressAbsolute = 1

    class SwizzleMode(IntEnum):
        """Values for the 3-bit swizzle mode field (bits 61-63)."""

        NoSwizzling = 0
        Swizzle128B32BAtomic = 1
        Swizzle128B = 2
        Swizzle64B = 4
        Swizzle32B = 6

    # The three address/offset fields are masked to 18 bits and shifted
    # right by 4 when encoded, so their low 4 bits are discarded.
    # NOTE(review): presumably these are byte quantities aligned to 16 -
    # confirm against the PTX matrix-descriptor specification.
    matrix_start_address: int
    leading_dim_offset: int
    stride_dim_offset: int
    # Encoded into a 3-bit field.
    base_offset: int = 0
    leading_dim_mode: LeadingDimMode = LeadingDimMode.ByteOffsetRelative
    swizzle_mode: SwizzleMode = SwizzleMode.NoSwizzling

    def encode(self) -> int:
        """Pack the descriptor fields into a single 64-bit word.

        Bit layout written below (least significant bit is bit 0):
        0-13 matrix start address, 16-29 leading dimension offset,
        32-45 stride dimension offset, 46-48 the constant ``0b001``,
        49-51 base offset, 52 leading dimension mode, 61-63 swizzle
        mode. Bits not listed stay zero.
        """
        desc = uint64(0x0000_0000_0000_0000)
        # Addresses/offsets: keep 18 bits, drop the low 4, store 14 bits.
        desc = set_bits64(desc, (self.matrix_start_address & 0x3FFFF) >> 4, 0, 14)
        desc = set_bits64(desc, (self.leading_dim_offset & 0x3FFFF) >> 4, 16, 14)
        desc = set_bits64(desc, (self.stride_dim_offset & 0x3FFFF) >> 4, 32, 14)
        # Fixed constant field required at bits 46-48.
        desc = set_bits64(desc, 0b001, 46, 3)
        desc = set_bits64(desc, self.base_offset, 49, 3)
        desc = set_bit64(desc, 52, self.leading_dim_mode)
        desc = set_bits64(desc, self.swizzle_mode, 61, 3)
        return desc
250283

251284

@@ -255,6 +288,7 @@ def encode(self) -> int:
255288
"Tcgen05InstructionDescriptor",
256289
"Tcgen05Mxf8f6f4InstructionDescriptor",
257290
"Tcgen05Mxf4InstructionDescriptor",
291+
"Tcgen05SharedMemoryDescriptor",
258292
"tcgen05_alloc",
259293
"tcgen05_dealloc",
260294
"tcgen05_commit",

experimental/cuda-lang/test/test_tcgen05_descriptors.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ def encode_tcgen05_mxf4_instruction_descriptor():
6161
).encode()
6262

6363

64+
def encode_tcgen05_shared_memory_descriptor():
    """Build a shared memory descriptor with every field set and return its encoding."""
    return cl.Tcgen05SharedMemoryDescriptor(
        matrix_start_address=0x12340,
        leading_dim_offset=0x23450,
        stride_dim_offset=0x34560,
        base_offset=5,
        leading_dim_mode=cl.Tcgen05SharedMemoryDescriptor.LeadingDimMode.ByteAddressAbsolute,
        swizzle_mode=cl.Tcgen05SharedMemoryDescriptor.SwizzleMode.Swizzle128B,
    ).encode()
73+
74+
6475
@pytest.mark.parametrize(
6576
"encode_descriptor,expected",
6677
[
@@ -105,6 +116,16 @@ def encode_tcgen05_mxf4_instruction_descriptor():
105116
| (2 << 29)
106117
| (1 << 31),
107118
),
119+
(
120+
encode_tcgen05_shared_memory_descriptor,
121+
(((0x12340 & 0x3FFFF) >> 4) << 0)
122+
| (((0x23450 & 0x3FFFF) >> 4) << 16)
123+
| (((0x34560 & 0x3FFFF) >> 4) << 32)
124+
| (0b001 << 46)
125+
| (5 << 49)
126+
| (1 << 52)
127+
| (2 << 61),
128+
),
108129
],
109130
)
110131
def test_tcgen05_instruction_descriptor_encode_on_gpu(encode_descriptor, expected):

0 commit comments

Comments
 (0)