Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 38 additions & 37 deletions src/infiniop/ops/paged_attention_prefill/cuda/kernel_v2.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

namespace op::paged_attention_prefill::cuda {

__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *cu_seqlens_q, size_t num_seqs) {
template <typename Tindex>
__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const Tindex *cu_seqlens_q, size_t num_seqs) {
size_t low = 0, high = (num_seqs == 0) ? 0 : (num_seqs - 1);
while (low <= high) {
size_t mid = (low + high) >> 1;
Expand All @@ -43,8 +44,8 @@ __device__ void PagedAttentionPrefillWarpKernel(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
Expand Down Expand Up @@ -73,8 +74,8 @@ __device__ void PagedAttentionPrefillWarpKernel(
const int seq_idx = static_cast<int>(blockIdx.y);
const int q_token_local = static_cast<int>(blockIdx.z);

const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_token_local >= q_len) {
return;
Expand Down Expand Up @@ -256,8 +257,8 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_heads,
size_t num_seqs,
Expand Down Expand Up @@ -291,9 +292,9 @@ __global__ void PagedAttentionPrefillWarpGlobalKernel(
return;
}

const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs);
const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const size_t seq_idx = find_seq_id<Tindex>(global_token_idx, cu_seqlens_q_, num_seqs);
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);

const int q_token_local = static_cast<int>(global_token_idx - static_cast<size_t>(q_start));
Expand Down Expand Up @@ -477,8 +478,8 @@ __global__ void PagedAttentionPrefillReferenceKernel(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_heads,
size_t num_kv_heads,
Expand Down Expand Up @@ -506,7 +507,7 @@ __global__ void PagedAttentionPrefillReferenceKernel(
return;
}

const size_t seq_idx = find_seq_id(global_token_idx, cu_seqlens_q_, num_seqs);
const size_t seq_idx = find_seq_id<Tindex>(global_token_idx, cu_seqlens_q_, num_seqs);
const size_t q_token_idx = global_token_idx - static_cast<size_t>(cu_seqlens_q_[seq_idx]);
const size_t q_len = static_cast<size_t>(cu_seqlens_q_[seq_idx + 1] - cu_seqlens_q_[seq_idx]);

Expand Down Expand Up @@ -595,8 +596,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
Expand Down Expand Up @@ -632,8 +633,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernel(
const int seq_idx = static_cast<int>(blockIdx.y);
const int m_block = static_cast<int>(blockIdx.z);

const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
Expand Down Expand Up @@ -865,8 +866,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
Expand Down Expand Up @@ -904,8 +905,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelined(
const int seq_idx = static_cast<int>(blockIdx.y);
const int m_block = static_cast<int>(blockIdx.z);

const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
Expand Down Expand Up @@ -1312,8 +1313,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
Expand Down Expand Up @@ -1350,8 +1351,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelPipelinedSplitKv(
const int head_idx = static_cast<int>(blockIdx.x);
const int seq_idx = static_cast<int>(blockIdx.y);

const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
Expand Down Expand Up @@ -1778,8 +1779,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
const Tdata *k_cache_,
const Tdata *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
Expand Down Expand Up @@ -1815,8 +1816,8 @@ __device__ void PagedAttentionPrefillWarpCtaKernelKOnly(
const int seq_idx = static_cast<int>(blockIdx.y);
const int m_block = static_cast<int>(blockIdx.z);

const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
Expand Down Expand Up @@ -2115,12 +2116,12 @@ __device__ __forceinline__ void PagedAttentionPrefillMmaScoreUpdateRow(
}
}

template <int kWarpSize, int kHeadDim, int kDimsPerThread>
template <typename Tindex, int kWarpSize, int kHeadDim, int kDimsPerThread>
__device__ __forceinline__ void PagedAttentionPrefillMmaScoreWriteRow(
int lane,
bool active,
int q_token_local,
int64_t q_start,
Tindex q_start,
int head_idx,
half *out_,
ptrdiff_t o_stride,
Expand Down Expand Up @@ -2153,8 +2154,8 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(
const half *k_cache_,
const half *v_cache_,
const Tindex *block_tables_,
const int64_t *total_kv_lens_,
const int64_t *cu_seqlens_q_,
const Tindex *total_kv_lens_,
const Tindex *cu_seqlens_q_,
const float *alibi_slopes_,
size_t num_kv_heads,
float scale,
Expand Down Expand Up @@ -2198,8 +2199,8 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(
const int seq_idx = static_cast<int>(blockIdx.y);
const int m_block = static_cast<int>(blockIdx.z);

const int64_t q_start = cu_seqlens_q_[seq_idx];
const int64_t q_end = cu_seqlens_q_[seq_idx + 1];
const Tindex q_start = cu_seqlens_q_[seq_idx];
const Tindex q_end = cu_seqlens_q_[seq_idx + 1];
const int q_len = static_cast<int>(q_end - q_start);
if (q_len <= 0) {
return;
Expand Down Expand Up @@ -2353,11 +2354,11 @@ __device__ void PagedAttentionPrefillWarpCta8MmaHd128Kernel(

// Write outputs.
if (row0 < kBlockM) {
PagedAttentionPrefillMmaScoreWriteRow<kWarpSize, kHeadDim, kDimsPerThread>(
PagedAttentionPrefillMmaScoreWriteRow<Tindex, kWarpSize, kHeadDim, kDimsPerThread>(
lane, active0, m_start + row0, q_start, head_idx, out_, o_stride, o_head_stride, l0, acc0);
}
if (row1 < kBlockM) {
PagedAttentionPrefillMmaScoreWriteRow<kWarpSize, kHeadDim, kDimsPerThread>(
PagedAttentionPrefillMmaScoreWriteRow<Tindex, kWarpSize, kHeadDim, kDimsPerThread>(
lane, active1, m_start + row1, q_start, head_idx, out_, o_stride, o_head_stride, l1, acc1);
}
}
Expand Down
10 changes: 7 additions & 3 deletions src/infiniop/ops/paged_attention_prefill/info.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,13 @@ class PagedAttentionPrefillInfo {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

// Keep it simple: require total_kv_lens + cum_seqlens_q to be int64 for now
// (matches current paged_attention_prefill signature). We will convert to int32 internally later.
if (total_kv_lens_desc->dtype() != INFINI_DTYPE_I64 || cum_seqlens_q_desc->dtype() != INFINI_DTYPE_I64) {
// Index tensors (total_kv_lens, cum_seqlens_q) accept int32/uint32 to match
// mainstream paged-attention implementations (e.g., vLLM / FlashAttention-2);
// int64 remains supported for backward compatibility with the previous API.
if (!((total_kv_lens_desc->dtype() == INFINI_DTYPE_I64) || (total_kv_lens_desc->dtype() == INFINI_DTYPE_I32) || (total_kv_lens_desc->dtype() == INFINI_DTYPE_U32))) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

if (!((cum_seqlens_q_desc->dtype() == INFINI_DTYPE_I64) || (cum_seqlens_q_desc->dtype() == INFINI_DTYPE_I32) || (cum_seqlens_q_desc->dtype() == INFINI_DTYPE_U32))) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

Expand Down
Loading