66from enum import Enum , IntEnum
77from typing import Any , Literal
88
9+ from .._datatype import uint32 , uint64
910from cuda .lang ._execution import stub
10- from .bits import set_bit , set_bits
11+ from .bits import set_bit32 , set_bit64 , set_bits32 , set_bits64
1112from .nvvm import P3 , P6
1213
1314
@@ -155,20 +156,20 @@ class Tcgen05InstructionDescriptor:
155156 max_shift : MaxShift = MaxShift .NoShift
156157
def encode(self) -> int:
    """Pack this instruction descriptor's fields into one 32-bit word.

    The word starts as ``uint32(0)`` and every field is written into it
    through ``set_bit32`` / ``set_bits32``; the final word is returned.

    NOTE(review): the two trailing integers passed to ``set_bits32`` look
    like (start bit, width) and the middle argument of ``set_bit32`` like
    a bit index -- confirm against the ``.bits`` helpers.
    """
    word = uint32(0x0000_0000)

    # Sparsity controls and the saturate flag.
    word = set_bits32(word, self.sparsity_selector, 0, 2)
    word = set_bit32(word, 2, self.sparse)
    word = set_bit32(word, 3, self.saturate)

    # Element types of the D, A and B operands.
    word = set_bits32(word, self.d_type, 4, 2)
    word = set_bits32(word, self.a_type, 7, 3)
    word = set_bits32(word, self.b_type, 10, 3)

    # Negate / transpose flags for A and B.
    word = set_bit32(word, 13, self.negate_a)
    word = set_bit32(word, 14, self.negate_b)
    word = set_bit32(word, 15, self.transpose_a)
    word = set_bit32(word, 16, self.transpose_b)

    # n is stored without its 3 low bits, m without its 4 low bits.
    word = set_bits32(word, self.n >> 3, 17, 6)
    word = set_bits32(word, self.m >> 4, 24, 5)

    word = set_bits32(word, self.max_shift, 30, 2)
    return word
173174
174175
@@ -193,19 +194,19 @@ class Tcgen05Mxf8f6f4InstructionDescriptor:
193194 a_scale_id : Literal [0 , 1 , 2 , 3 ] = 0
194195
def encode(self) -> int:
    """Pack this mxf8f6f4 instruction descriptor into one 32-bit word.

    The word starts as ``uint32(0)`` and every field is written into it
    through ``set_bit32`` / ``set_bits32``; the final word is returned.

    NOTE(review): the two trailing integers passed to ``set_bits32`` look
    like (start bit, width) and the middle argument of ``set_bit32`` like
    a bit index -- confirm against the ``.bits`` helpers.
    """
    word = uint32(0x0000_0000)

    word = set_bit32(word, 2, self.sparse)
    word = set_bits32(word, self.b_scale_id, 4, 2)

    # Element types of the A and B operands.
    word = set_bits32(word, self.a_type, 7, 3)
    word = set_bits32(word, self.b_type, 10, 3)

    # Negate / transpose flags for A and B.
    word = set_bit32(word, 13, self.negate_a)
    word = set_bit32(word, 14, self.negate_b)
    word = set_bit32(word, 15, self.transpose_a)
    word = set_bit32(word, 16, self.transpose_b)

    # n is stored without its 3 low bits.
    word = set_bits32(word, self.n >> 3, 17, 6)
    word = set_bit32(word, 23, self.scale_format)

    # m is stored without its 7 low bits.
    word = set_bits32(word, self.m >> 7, 27, 2)
    word = set_bits32(word, self.a_scale_id, 29, 2)
    return word
210211
211212
@@ -232,20 +233,52 @@ class Tcgen05Mxf4InstructionDescriptor:
232233 k_dimension : KDimension = KDimension .DenseK64OrSparseK128
233234
def encode(self) -> int:
    """Pack this mxf4 instruction descriptor into one 32-bit word.

    The word starts as ``uint32(0)`` and every field is written into it
    through ``set_bit32`` / ``set_bits32``; the final word is returned.

    NOTE(review): the two trailing integers passed to ``set_bits32`` look
    like (start bit, width) and the middle argument of ``set_bit32`` like
    a bit index -- confirm against the ``.bits`` helpers.
    """
    word = uint32(0x0000_0000)

    word = set_bit32(word, 2, self.sparse)
    word = set_bits32(word, self.b_scale_id, 4, 2)

    # Element types of the A and B operands.
    word = set_bits32(word, self.a_type, 7, 3)
    word = set_bits32(word, self.b_type, 10, 2)

    # Negate / transpose flags for A and B.
    word = set_bit32(word, 13, self.negate_a)
    word = set_bit32(word, 14, self.negate_b)
    word = set_bit32(word, 15, self.transpose_a)
    word = set_bit32(word, 16, self.transpose_b)

    # n is stored without its 3 low bits.
    word = set_bits32(word, self.n >> 3, 17, 6)
    word = set_bit32(word, 23, self.scale_format)

    # m is stored without its 7 low bits.
    word = set_bits32(word, self.m >> 7, 27, 2)
    word = set_bits32(word, self.a_scale_id, 29, 2)

    word = set_bit32(word, 31, self.k_dimension)
    return word
251+
252+
@dataclass(frozen=True)
class Tcgen05SharedMemoryDescriptor:
    """Immutable set of fields that ``encode`` packs into a 64-bit word.

    ``matrix_start_address``, ``leading_dim_offset`` and
    ``stride_dim_offset`` are required; the remaining fields default to
    ``0`` / ``ByteOffsetRelative`` / ``NoSwizzling``.
    """

    class LeadingDimMode(IntEnum):
        """Values accepted by the ``leading_dim_mode`` field."""

        ByteOffsetRelative = 0
        ByteAddressAbsolute = 1

    class SwizzleMode(IntEnum):
        """Values accepted by the ``swizzle_mode`` field."""

        NoSwizzling = 0
        Swizzle128B32BAtomic = 1
        Swizzle128B = 2
        Swizzle64B = 4
        Swizzle32B = 6

    matrix_start_address: int
    leading_dim_offset: int
    stride_dim_offset: int
    base_offset: int = 0
    leading_dim_mode: LeadingDimMode = LeadingDimMode.ByteOffsetRelative
    swizzle_mode: SwizzleMode = SwizzleMode.NoSwizzling

    def encode(self) -> int:
        """Pack the descriptor fields into one 64-bit word and return it.

        The word starts as ``uint64(0)`` and is filled in through
        ``set_bits64`` / ``set_bit64``.

        NOTE(review): the two trailing integers passed to ``set_bits64``
        look like (start bit, width) and the middle argument of
        ``set_bit64`` like a bit index -- confirm against ``.bits``.
        """
        # Each address/offset is masked to its low 18 bits and then has its
        # 4 low bits dropped before being stored.
        start_field = (self.matrix_start_address & 0x3FFFF) >> 4
        leading_field = (self.leading_dim_offset & 0x3FFFF) >> 4
        stride_field = (self.stride_dim_offset & 0x3FFFF) >> 4

        word = uint64(0x0000_0000_0000_0000)
        word = set_bits64(word, start_field, 0, 14)
        word = set_bits64(word, leading_field, 16, 14)
        word = set_bits64(word, stride_field, 32, 14)
        # The constant 0b001 is always written here, independent of any field.
        word = set_bits64(word, 0b001, 46, 3)
        word = set_bits64(word, self.base_offset, 49, 3)
        word = set_bit64(word, 52, self.leading_dim_mode)
        word = set_bits64(word, self.swizzle_mode, 61, 3)
        return word
250283
251284
@@ -255,6 +288,7 @@ def encode(self) -> int:
255288 "Tcgen05InstructionDescriptor" ,
256289 "Tcgen05Mxf8f6f4InstructionDescriptor" ,
257290 "Tcgen05Mxf4InstructionDescriptor" ,
291+ "Tcgen05SharedMemoryDescriptor" ,
258292 "tcgen05_alloc" ,
259293 "tcgen05_dealloc" ,
260294 "tcgen05_commit" ,
0 commit comments