diff --git a/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh b/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh index 3fbaf40f8..da6e1810c 100644 --- a/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh +++ b/src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh @@ -16,7 +16,8 @@ namespace op::paged_attention_prefill::cuda { -__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cu_seqlens_q, size_t num_seqs) { +template +__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const Tindex *cu_seqlens_q, size_t num_seqs) { size_t low = 0, high = (num_seqs == 0) ? 0 : (num_seqs - 1); while (low <= high) { size_t mid = (low + high) >> 1; @@ -43,8 +44,8 @@ __device__ void PagedAttentionPrefillWarpKernel( const Tdata *k_cache_, const Tdata *v_cache_, const Tindex *block_tables_, - const int64_t *total_kv_lens_, - const int64_t *cu_seqlens_q_, + const Tindex *total_kv_lens_, + const Tindex *cu_seqlens_q_, const float *alibi_slopes_, size_t num_kv_heads, float scale, @@ -73,8 +74,8 @@ __device__ void PagedAttentionPrefillWarpKernel( const int seq_idx = static_cast(blockIdx.y); const int q_token_local = static_cast(blockIdx.z); - const int64_t q_start = cu_seqlens_q_[seq_idx]; - const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const Tindex q_start = cu_seqlens_q_[seq_idx]; + const Tindex q_end = cu_seqlens_q_[seq_idx + 1]; const int q_len = static_cast(q_end - q_start); if (q_token_local >= q_len) { return; @@ -256,8 +257,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel( const Tdata *k_cache_, const Tdata *v_cache_, const Tindex *block_tables_, - const int64_t *total_kv_lens_, - const int64_t *cu_seqlens_q_, + const Tindex *total_kv_lens_, + const Tindex *cu_seqlens_q_, const float *alibi_slopes_, size_t num_heads, size_t num_seqs, @@ -291,9 +292,9 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel( return; } - const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, 
num_seqs); - const int64_t q_start = cu_seqlens_q_[seq_idx]; - const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs); + const Tindex q_start = cu_seqlens_q_[seq_idx]; + const Tindex q_end = cu_seqlens_q_[seq_idx + 1]; const int q_len = static_cast(q_end - q_start); const int q_token_local = static_cast(global_token_idx - static_cast(q_start)); @@ -477,8 +478,8 @@ __global__ void PagedAttentionPrefillReferenceKernel( const Tdata *k_cache_, const Tdata *v_cache_, const Tindex *block_tables_, - const int64_t *total_kv_lens_, - const int64_t *cu_seqlens_q_, + const Tindex *total_kv_lens_, + const Tindex *cu_seqlens_q_, const float *alibi_slopes_, size_t num_heads, size_t num_kv_heads, @@ -506,7 +507,7 @@ __global__ void PagedAttentionPrefillReferenceKernel( return; } - const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs); + const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs); const size_t q_token_idx = global_token_idx - static_cast(cu_seqlens_q_[seq_idx]); const size_t q_len = static_cast(cu_seqlens_q_[seq_idx + 1] - cu_seqlens_q_[seq_idx]); @@ -595,8 +596,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel( const Tdata *k_cache_, const Tdata *v_cache_, const Tindex *block_tables_, - const int64_t *total_kv_lens_, - const int64_t *cu_seqlens_q_, + const Tindex *total_kv_lens_, + const Tindex *cu_seqlens_q_, const float *alibi_slopes_, size_t num_kv_heads, float scale, @@ -632,8 +633,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel( const int seq_idx = static_cast(blockIdx.y); const int m_block = static_cast(blockIdx.z); - const int64_t q_start = cu_seqlens_q_[seq_idx]; - const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const Tindex q_start = cu_seqlens_q_[seq_idx]; + const Tindex q_end = cu_seqlens_q_[seq_idx + 1]; const int q_len = static_cast(q_end - q_start); if (q_len <= 0) { return; @@ -865,8 +866,8 @@ __device__ void 
PagedAttentionPrefillWarpCtaKernelPipelined( const Tdata *k_cache_, const Tdata *v_cache_, const Tindex *block_tables_, - const int64_t *total_kv_lens_, - const int64_t *cu_seqlens_q_, + const Tindex *total_kv_lens_, + const Tindex *cu_seqlens_q_, const float *alibi_slopes_, size_t num_kv_heads, float scale, @@ -904,8 +905,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined( const int seq_idx = static_cast(blockIdx.y); const int m_block = static_cast(blockIdx.z); - const int64_t q_start = cu_seqlens_q_[seq_idx]; - const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const Tindex q_start = cu_seqlens_q_[seq_idx]; + const Tindex q_end = cu_seqlens_q_[seq_idx + 1]; const int q_len = static_cast(q_end - q_start); if (q_len <= 0) { return; @@ -1312,8 +1313,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv( const Tdata *k_cache_, const Tdata *v_cache_, const Tindex *block_tables_, - const int64_t *total_kv_lens_, - const int64_t *cu_seqlens_q_, + const Tindex *total_kv_lens_, + const Tindex *cu_seqlens_q_, const float *alibi_slopes_, size_t num_kv_heads, float scale, @@ -1350,8 +1351,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv( const int head_idx = static_cast(blockIdx.x); const int seq_idx = static_cast(blockIdx.y); - const int64_t q_start = cu_seqlens_q_[seq_idx]; - const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const Tindex q_start = cu_seqlens_q_[seq_idx]; + const Tindex q_end = cu_seqlens_q_[seq_idx + 1]; const int q_len = static_cast(q_end - q_start); if (q_len <= 0) { return; @@ -1778,8 +1779,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly( const Tdata *k_cache_, const Tdata *v_cache_, const Tindex *block_tables_, - const int64_t *total_kv_lens_, - const int64_t *cu_seqlens_q_, + const Tindex *total_kv_lens_, + const Tindex *cu_seqlens_q_, const float *alibi_slopes_, size_t num_kv_heads, float scale, @@ -1815,8 +1816,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly( const 
int seq_idx = static_cast(blockIdx.y); const int m_block = static_cast(blockIdx.z); - const int64_t q_start = cu_seqlens_q_[seq_idx]; - const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const Tindex q_start = cu_seqlens_q_[seq_idx]; + const Tindex q_end = cu_seqlens_q_[seq_idx + 1]; const int q_len = static_cast(q_end - q_start); if (q_len <= 0) { return; @@ -2115,12 +2116,12 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreUpdateRow( } } -template +template __device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow( int lane, bool active, int q_token_local, - int64_t q_start, + Tindex q_start, int head_idx, half *out_, ptrdiff_t o_stride, @@ -2153,8 +2154,8 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel( const half *k_cache_, const half *v_cache_, const Tindex *block_tables_, - const int64_t *total_kv_lens_, - const int64_t *cu_seqlens_q_, + const Tindex *total_kv_lens_, + const Tindex *cu_seqlens_q_, const float *alibi_slopes_, size_t num_kv_heads, float scale, @@ -2198,8 +2199,8 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel( const int seq_idx = static_cast(blockIdx.y); const int m_block = static_cast(blockIdx.z); - const int64_t q_start = cu_seqlens_q_[seq_idx]; - const int64_t q_end = cu_seqlens_q_[seq_idx + 1]; + const Tindex q_start = cu_seqlens_q_[seq_idx]; + const Tindex q_end = cu_seqlens_q_[seq_idx + 1]; const int q_len = static_cast(q_end - q_start); if (q_len <= 0) { return; @@ -2353,11 +2354,11 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel( // Write outputs. 
if (row0 < kBlockM) { - PagedAttentionPrefillMmaScoreWriteRow( + PagedAttentionPrefillMmaScoreWriteRow( lane, active0, m_start + row0, q_start, head_idx, out_, o_stride, o_head_stride, l0, acc0); } if (row1 < kBlockM) { - PagedAttentionPrefillMmaScoreWriteRow( + PagedAttentionPrefillMmaScoreWriteRow( lane, active1, m_start + row1, q_start, head_idx, out_, o_stride, o_head_stride, l1, acc1); } } diff --git a/src/infiniop/ops/paged_attention_prefill/info.h b/src/infiniop/ops/paged_attention_prefill/info.h index a40f4ceaf..9f1307c3b 100644 --- a/src/infiniop/ops/paged_attention_prefill/info.h +++ b/src/infiniop/ops/paged_attention_prefill/info.h @@ -80,9 +80,13 @@ class PagedAttentionPrefillInfo { return INFINI_STATUS_BAD_TENSOR_DTYPE; } - // Keep it simple: require total_kv_lens + cum_seqlens_q to be int64 for now - // (matches current paged_attention_prefill signature). We will convert to int32 internally later. - if (total_kv_lens_desc->dtype() != INFINI_DTYPE_I64 || cum_seqlens_q_desc->dtype() != INFINI_DTYPE_I64) { + // Index tensors prefer int32_t to match mainstream paged-attention implementations + // (e.g., vLLM / FlashAttention2); int64_t and uint32_t are also accepted below.
+ if (!((total_kv_lens_desc->dtype() == INFINI_DTYPE_I64) || (total_kv_lens_desc->dtype() == INFINI_DTYPE_I32) || (total_kv_lens_desc->dtype() == INFINI_DTYPE_U32))) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + if (!((cum_seqlens_q_desc->dtype() == INFINI_DTYPE_I64) || (cum_seqlens_q_desc->dtype() == INFINI_DTYPE_I32) || (cum_seqlens_q_desc->dtype() == INFINI_DTYPE_U32))) { return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu index 04107fb4e..042de64b2 100644 --- a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu +++ b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu @@ -47,8 +47,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128Warp( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -80,8 +80,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64Warp( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -113,8 +113,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -149,8 +149,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t 
*cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -185,8 +185,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -221,8 +221,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8N128( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -258,8 +258,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -294,8 +294,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Pipe( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -330,8 +330,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8Mma( const half *k_cache, const half *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -365,8 +365,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8Pipe( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, 
- const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -405,8 +405,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta8PipeSplitKv( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -448,8 +448,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta8PipeSplitKv( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -513,8 +513,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd128WarpCta16( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -549,8 +549,8 @@ INFINIOP_CUDA_KERNEL PagedAttentionPrefillHd64WarpCta16( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_kv_heads, float scale, @@ -585,8 +585,8 @@ infiniStatus_t launch_prefill_ref( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_heads, size_t num_seqs, @@ -646,8 +646,8 @@ infiniStatus_t launch_prefill_warp( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t 
*cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_heads, size_t num_seqs, @@ -713,8 +713,8 @@ infiniStatus_t launch_prefill( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_heads, size_t num_seqs, @@ -779,8 +779,8 @@ infiniStatus_t launch_prefill_warpcta8( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_heads, size_t num_seqs, @@ -845,8 +845,8 @@ infiniStatus_t launch_prefill_warpcta8pipe( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_heads, size_t num_seqs, @@ -911,8 +911,8 @@ infiniStatus_t launch_prefill_warpcta8mma( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_heads, size_t num_seqs, @@ -1028,8 +1028,8 @@ infiniStatus_t launch_prefill_warpcta8pipe_splitkv( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_heads, size_t num_seqs, @@ -1123,8 +1123,8 @@ infiniStatus_t launch_prefill_warpcta8n128( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex 
*cu_seqlens_q, const float *alibi_slopes, size_t num_heads, size_t num_seqs, @@ -1178,8 +1178,8 @@ infiniStatus_t launch_prefill_warpcta16( const Tdata *k_cache, const Tdata *v_cache, const Tindex *block_tables, - const int64_t *total_kv_lens, - const int64_t *cu_seqlens_q, + const Tindex *total_kv_lens, + const Tindex *cu_seqlens_q, const float *alibi_slopes, size_t num_heads, size_t num_seqs, @@ -1311,8 +1311,10 @@ infiniStatus_t Descriptor::calculate( auto stream = static_cast(stream_); const float *alibi_ptr = (alibi_slopes == nullptr) ? nullptr : static_cast(alibi_slopes); - const auto *total_kv_lens_i64 = static_cast(total_kv_lens); - const auto *cu_seqlens_q_i64 = static_cast(cum_seqlens_q); + // const auto *total_kv_lens_i64 = static_cast(total_kv_lens); + // const auto *cu_seqlens_q_i64 = static_cast(cum_seqlens_q); + const void *total_kv_lens_i64 = total_kv_lens; + const void *cu_seqlens_q_i64 = cum_seqlens_q; bool use_splitkv = false; if (const char *env = std::getenv("INFINIOP_FLASH_PREFILL_SPLITKV")) { @@ -1346,7 +1348,7 @@ infiniStatus_t Descriptor::calculate( float *partial_m = partial_acc + static_cast(num_splits) * n * _info.head_size; float *partial_l = partial_m + static_cast(num_splits) * n; - // Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q are currently always int64. + // Dispatch by (Tdata, Tindex). total_kv_lens + cu_seqlens_q may be either int32 or int64; cast per the dispatched Tindex.
#define DISPATCH_SPLITKV(Tindex, Tdata, BT_PTR) \ return launch_prefill_warpcta8pipe_splitkv( \ partial_acc, partial_m, partial_l, num_splits, \ @@ -1355,7 +1357,9 @@ infiniStatus_t Descriptor::calculate( static_cast(k_cache), \ static_cast(v_cache), \ static_cast(BT_PTR), \ - total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + static_cast(total_kv_lens_i64), \ + static_cast(cu_seqlens_q_i64), \ + alibi_ptr, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.block_table_batch_stride, \ @@ -1363,7 +1367,6 @@ infiniStatus_t Descriptor::calculate( _info.k_batch_stride, _info.k_row_stride, _info.k_head_stride, \ _info.v_batch_stride, _info.v_row_stride, _info.v_head_stride, \ _info.o_stride, _info.o_head_stride, stream) - if (_info.dtype == INFINI_DTYPE_F16) { if (_info.index_dtype == INFINI_DTYPE_I64) { DISPATCH_SPLITKV(int64_t, half, block_tables); @@ -1425,7 +1428,7 @@ infiniStatus_t Descriptor::calculate( return launch_prefill_warp( \ static_cast(out), static_cast(q), \ static_cast(k_cache), static_cast(v_cache), \ - static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + static_cast(block_tables), static_cast(total_kv_lens_i64), static_cast(cu_seqlens_q_i64), alibi_ptr, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.block_table_batch_stride, \ @@ -1438,7 +1441,7 @@ infiniStatus_t Descriptor::calculate( return launch_prefill( \ static_cast(out), static_cast(q), \ static_cast(k_cache), static_cast(v_cache), \ - static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + static_cast(block_tables), static_cast(total_kv_lens_i64), static_cast(cu_seqlens_q_i64), alibi_ptr, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.head_size, _info.scale, 
_info.max_num_blocks_per_seq, _info.page_block_size, \ _info.block_table_batch_stride, \ @@ -1451,7 +1454,7 @@ infiniStatus_t Descriptor::calculate( return launch_prefill_warpcta8( \ static_cast(out), static_cast(q), \ static_cast(k_cache), static_cast(v_cache), \ - static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + static_cast(block_tables), static_cast(total_kv_lens_i64), static_cast(cu_seqlens_q_i64), alibi_ptr, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.block_table_batch_stride, \ @@ -1464,7 +1467,7 @@ infiniStatus_t Descriptor::calculate( return launch_prefill_warpcta8pipe( \ static_cast(out), static_cast(q), \ static_cast(k_cache), static_cast(v_cache), \ - static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + static_cast(block_tables), static_cast(total_kv_lens_i64), static_cast(cu_seqlens_q_i64), alibi_ptr, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.block_table_batch_stride, \ @@ -1478,7 +1481,7 @@ infiniStatus_t Descriptor::calculate( return launch_prefill_warpcta8mma( \ static_cast(out), static_cast(q), \ static_cast(k_cache), static_cast(v_cache), \ - static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + static_cast(block_tables), static_cast(total_kv_lens_i64), static_cast(cu_seqlens_q_i64), alibi_ptr, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.block_table_batch_stride, \ @@ -1492,7 +1495,7 @@ infiniStatus_t Descriptor::calculate( return launch_prefill_warpcta8n128( \ static_cast(out), static_cast(q), \ static_cast(k_cache), static_cast(v_cache), \ - static_cast(block_tables), total_kv_lens_i64, 
cu_seqlens_q_i64, alibi_ptr, \ + static_cast(block_tables), static_cast(total_kv_lens_i64), static_cast(cu_seqlens_q_i64), alibi_ptr, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.block_table_batch_stride, \ @@ -1505,7 +1508,7 @@ infiniStatus_t Descriptor::calculate( return launch_prefill_warpcta16( \ static_cast(out), static_cast(q), \ static_cast(k_cache), static_cast(v_cache), \ - static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + static_cast(block_tables), static_cast(total_kv_lens_i64), static_cast(cu_seqlens_q_i64), alibi_ptr, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.block_table_batch_stride, \ @@ -1518,7 +1521,7 @@ infiniStatus_t Descriptor::calculate( return launch_prefill_ref( \ static_cast(out), static_cast(q), \ static_cast(k_cache), static_cast(v_cache), \ - static_cast(block_tables), total_kv_lens_i64, cu_seqlens_q_i64, alibi_ptr, \ + static_cast(block_tables), static_cast(total_kv_lens_i64), static_cast(cu_seqlens_q_i64), alibi_ptr, \ _info.num_heads, _info.num_seqs, _info.num_kv_heads, _info.total_q_tokens, \ _info.head_size, _info.scale, _info.max_num_blocks_per_seq, _info.page_block_size, \ _info.block_table_batch_stride, \ diff --git a/test/infinicore/ops/paged_attention_prefill.py b/test/infinicore/ops/paged_attention_prefill.py index a5a21cbf9..c05dc1e97 100644 --- a/test/infinicore/ops/paged_attention_prefill.py +++ b/test/infinicore/ops/paged_attention_prefill.py @@ -31,6 +31,8 @@ _TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16] +_INDEX_DTYPES = [infinicore.int32, infinicore.int64] + class SimpleCacheManager: def __init__(self, num_blocks, block_size): @@ -72,16 +74,16 @@ def parse_test_cases(): scale = head_size**-0.5 num_blocks = 8192 manager = 
SimpleCacheManager(num_blocks, block_size) - kv_lens = torch.zeros(num_seqs, dtype=torch.int64) + kv_lens = torch.zeros(num_seqs, dtype=torch.int32) persistent_k = torch.zeros((num_blocks, num_kv_heads, block_size, head_size)) persistent_v = torch.zeros((num_blocks, num_kv_heads, block_size, head_size)) for r in range(num_rounds): - q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int64) + q_lens = torch.randint(1, max_step_len + 1, (num_seqs,), dtype=torch.int32) kv_lens = kv_lens + q_lens total_q_tokens = q_lens.sum().item() - cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int64) + cum_seqlens_q = torch.zeros(num_seqs + 1, dtype=torch.int32) cum_seqlens_q[1:] = torch.cumsum(q_lens, dim=0) query_base = torch.randn((total_q_tokens, num_heads, head_size)) @@ -106,53 +108,53 @@ def parse_test_cases(): ) for dtype in _TENSOR_DTYPES: - tolerance = _TOLERANCE_MAP.get(dtype) - - test_cases.append( - TestCase( - inputs=[ - TensorSpec.from_tensor( - query_base.shape, - init_mode=TensorInitializer.MANUAL, - set_tensor=query_base.clone(), - dtype=dtype, - ), - TensorSpec.from_tensor( - persistent_k.shape, - init_mode=TensorInitializer.MANUAL, - set_tensor=persistent_k.clone(), - dtype=dtype, - ), - TensorSpec.from_tensor( - persistent_v.shape, - init_mode=TensorInitializer.MANUAL, - set_tensor=persistent_v.clone(), - dtype=dtype, - ), - TensorSpec.from_tensor( - padded_tables.shape, - init_mode=TensorInitializer.MANUAL, - set_tensor=padded_tables.clone(), - dtype=infinicore.int64, - ), - TensorSpec.from_tensor( - kv_lens.shape, - init_mode=TensorInitializer.MANUAL, - set_tensor=kv_lens.clone(), - dtype=infinicore.int64, - ), - TensorSpec.from_tensor( - cum_seqlens_q.shape, - init_mode=TensorInitializer.MANUAL, - set_tensor=cum_seqlens_q.clone(), - dtype=infinicore.int64, - ), - ], - kwargs={"scale": scale}, - tolerance=tolerance, - description=f"PagedAttentionPrefill_Round_{r}_{str(dtype).split('.')[-1]}", + for idx_dtype in _INDEX_DTYPES: # Loop 
through both I32 and I64 + tolerance = _TOLERANCE_MAP.get(dtype) + test_cases.append( + TestCase( + inputs=[ + TensorSpec.from_tensor( + query_base.shape, + init_mode=TensorInitializer.MANUAL, + set_tensor=query_base.clone(), + dtype=dtype, + ), + TensorSpec.from_tensor( + persistent_k.shape, + init_mode=TensorInitializer.MANUAL, + set_tensor=persistent_k.clone(), + dtype=dtype, + ), + TensorSpec.from_tensor( + persistent_v.shape, + init_mode=TensorInitializer.MANUAL, + set_tensor=persistent_v.clone(), + dtype=dtype, + ), + TensorSpec.from_tensor( + padded_tables.shape, + init_mode=TensorInitializer.MANUAL, + set_tensor=padded_tables.clone(), + dtype=idx_dtype, + ), + TensorSpec.from_tensor( + kv_lens.shape, + init_mode=TensorInitializer.MANUAL, + set_tensor=kv_lens.clone(), + dtype=idx_dtype, + ), + TensorSpec.from_tensor( + cum_seqlens_q.shape, + init_mode=TensorInitializer.MANUAL, + set_tensor=cum_seqlens_q.clone(), + dtype=idx_dtype, + ), + ], + kwargs={"scale": scale}, + tolerance=tolerance, + description=f"PagedAttentionPrefill_Round_{r}_{str(dtype).split('.')[-1]}", + ) ) - ) return test_cases diff --git a/test/infiniop/paged_attention_prefill.py b/test/infiniop/paged_attention_prefill.py index 65d843fae..82e850bc6 100644 --- a/test/infiniop/paged_attention_prefill.py +++ b/test/infiniop/paged_attention_prefill.py @@ -23,13 +23,20 @@ # Configuration (Internal Use Only) # ============================================================================== _TEST_CASES = [ - # num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds - (1, 1, 1, 128, 8, 16, 1), - (1, 4, 4, 128, 8, 16, 4), - (2, 8, 8, 128, 16, 32, 2), - (4, 16, 16, 128, 8, 64, 3), - (8, 64, 64, 128, 8, 16, 5), - (16, 128, 128, 128, 8, 16, 4), + # num_seqs, num_heads, num_kv_heads, head_size, block_size, max_step_len, num_rounds, index_dtypes + # index_dtype: The data type used for memory indexing of block_tables, cum_seq_lens and seq_lens + (1, 1, 1, 128, 8, 16, 1, 
InfiniDtype.I32), + (1, 1, 1, 128, 8, 16, 1, InfiniDtype.I64), + (1, 4, 4, 128, 8, 16, 4, InfiniDtype.I32), + (1, 4, 4, 128, 8, 16, 4, InfiniDtype.I64), + (2, 8, 8, 128, 16, 32, 2, InfiniDtype.I32), + (2, 8, 8, 128, 16, 32, 2, InfiniDtype.I64), + (4, 16, 16, 128, 8, 64, 3, InfiniDtype.I32), + (4, 16, 16, 128, 8, 64, 3, InfiniDtype.I64), + (8, 64, 64, 128, 8, 16, 5, InfiniDtype.I32), + (8, 64, 64, 128, 8, 16, 5, InfiniDtype.I64), + (16, 128, 128, 128, 8, 16, 4, InfiniDtype.I32), + (16, 128, 128, 128, 8, 16, 4, InfiniDtype.I64), ] _TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16] @@ -124,13 +131,15 @@ def test( block_size, max_step_len, num_rounds, + index_dtype=InfiniDtype.I64, dtype=InfiniDtype.F16, sync=None, ): print( f"Testing PagedAttentionPrefill on {InfiniDeviceNames[device]} with " f"seqs:{num_seqs}, heads:{num_heads}, head_size:{head_size}, " - f"block:{block_size}, max_step_len:{max_step_len}, num_rounds:{num_rounds}, dtype:{InfiniDtypeNames[dtype]}" + f"block:{block_size}, max_step_len:{max_step_len}, num_rounds:{num_rounds}, dtype:{InfiniDtypeNames[dtype]}, " + f"index_dtype:{InfiniDtypeNames[index_dtype]}" ) # 1. Initialize persistent resources @@ -194,23 +203,26 @@ def test( out = TestTensor.from_torch(q_packed_tensors, dtype, device) out.actual_tensor().zero_() + # 3. 
Referencing index_dtype to set torch dtype + torch_idx_type = torch.int32 if index_dtype == InfiniDtype.I32 else torch.int64 + seq_lens = TestTensor.from_torch( - torch.tensor(seq_lens_list, dtype=torch.int64), InfiniDtype.I64, device + torch.tensor(seq_lens_list, dtype=torch_idx_type), index_dtype, device ) cum_seq_lens_q = TestTensor.from_torch( - torch.tensor(cum_seq_lens_q_list, dtype=torch.int64), - InfiniDtype.I64, + torch.tensor(cum_seq_lens_q_list, dtype=torch_idx_type), + index_dtype, device, ) max_blocks = max(len(t) for t in all_block_tables) padded_tables = [t + [0] * (max_blocks - len(t)) for t in all_block_tables] block_tables = TestTensor.from_torch( - torch.tensor(padded_tables, dtype=torch.int64), InfiniDtype.I64, device + torch.tensor(padded_tables, dtype=torch_idx_type), index_dtype, device ) - # 3. Reference Calculation + # 4. Reference Calculation def torch_paged_attention_multi_turn(): return ref_paged_attention_multi_turn( q_new.torch_tensor(), @@ -224,7 +236,7 @@ def torch_paged_attention_multi_turn(): ans = torch_paged_attention_multi_turn() - # 4. Infiniop Operator Execution + # 5. Infiniop Operator Execution descriptor = infiniopOperatorDescriptor_t() check_error( LIBINFINIOP.infiniopCreatePagedAttentionPrefillDescriptor( @@ -272,7 +284,7 @@ def lib_attn(): if sync: sync() - # 5. Validation + # 6. Validation atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) if DEBUG: debug(out.actual_tensor(), ans, atol=atol, rtol=rtol)