✨ feat: add nanovllm.ops module with XAttention estimation kernels
Add ops module ported from tzj/minference branch containing: - xattn.py: XAttention block importance estimation with Triton kernels - xattn_estimate(): standard estimation for sparse attention mask - xattn_estimate_chunked(): chunked prefill compatible version - flat_group_gemm_fuse_reshape(): fused stride reshape + GEMM kernel - softmax_fuse_block_sum(): online softmax + block-wise sum kernel - chunked_attention.py: Flash attention with LSE output for chunk merging - test_xattn_estimate_chunked.py: verification test (all seq_lens pass) This prepares the foundation for AttentionPolicy refactoring where XAttentionPolicy.estimate() will call these ops. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
38
nanovllm/ops/__init__.py
Normal file
38
nanovllm/ops/__init__.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""
|
||||||
|
Operators module for nano-vLLM.
|
||||||
|
|
||||||
|
This module contains low-level attention operators and kernels.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from nanovllm.ops.chunked_attention import (
|
||||||
|
flash_attn_with_lse,
|
||||||
|
merge_attention_outputs,
|
||||||
|
chunked_attention_varlen,
|
||||||
|
ChunkedPrefillState,
|
||||||
|
)
|
||||||
|
|
||||||
|
from nanovllm.ops.xattn import (
|
||||||
|
xattn_estimate,
|
||||||
|
xattn_estimate_chunked,
|
||||||
|
flat_group_gemm_fuse_reshape,
|
||||||
|
softmax_fuse_block_sum,
|
||||||
|
find_blocks_chunked,
|
||||||
|
create_causal_mask,
|
||||||
|
compute_sparsity,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# chunked_attention
|
||||||
|
"flash_attn_with_lse",
|
||||||
|
"merge_attention_outputs",
|
||||||
|
"chunked_attention_varlen",
|
||||||
|
"ChunkedPrefillState",
|
||||||
|
# xattn
|
||||||
|
"xattn_estimate",
|
||||||
|
"xattn_estimate_chunked",
|
||||||
|
"flat_group_gemm_fuse_reshape",
|
||||||
|
"softmax_fuse_block_sum",
|
||||||
|
"find_blocks_chunked",
|
||||||
|
"create_causal_mask",
|
||||||
|
"compute_sparsity",
|
||||||
|
]
|
||||||
624
nanovllm/ops/chunked_attention.py
Normal file
624
nanovllm/ops/chunked_attention.py
Normal file
@@ -0,0 +1,624 @@
|
|||||||
|
"""
|
||||||
|
Chunked attention implementation for CPU KV cache offloading.
|
||||||
|
|
||||||
|
This module implements flash attention with LSE (log-sum-exp) output,
|
||||||
|
enabling proper online softmax merging for chunked prefill.
|
||||||
|
|
||||||
|
Key functions:
|
||||||
|
- flash_attn_with_lse: Flash attention that returns output and LSE
|
||||||
|
- merge_attention_outputs: Merge outputs from multiple KV chunks
|
||||||
|
- chunked_prefill_attention: High-level interface for chunked attention
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import triton
|
||||||
|
import triton.language as tl
|
||||||
|
from typing import Tuple, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@triton.heuristics(
    {
        # These flags let specializations skip bounds masking when the
        # problem sizes divide evenly into the tile sizes (faster paths).
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel_with_lse(
    Q,
    K,
    V,
    Out,
    Lse,
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_ob,
    stride_oh,
    stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Flash attention forward kernel with LSE output.

    Grid layout: axis 0 tiles the query rows (BLOCK_M rows per program),
    axis 1 is the flattened (batch, head) index.

    Implements standard Flash Attention online softmax algorithm:
    - m_i: running max of attention scores
    - l_i: running sum of exp(scores - m_i)
    - acc_o: running sum of softmax(scores) @ V (unnormalized)

    Final output: acc_o / l_i
    Final LSE: m_i + log(l_i), stored per query row at
    ``Lse + (batch*head) * seqlen_q_rounded``.

    NOTE(review): CACHE_KEY_SEQLEN_Q / CACHE_KEY_SEQLEN_K are never read in
    the body; presumably they key Triton's compilation cache (as in the
    upstream flash-attn Triton kernel) — confirm against the reference.

    NOTE(review): head-dim stride is implicitly 1 in all pointer math, i.e.
    Q/K/V/Out must be contiguous along the last dimension — verify callers.

    NOTE(review): if a query row sees no valid keys (e.g. seqlen_k == 0),
    l_i stays 0 and the final divisions produce NaN/Inf; callers appear to
    always pass non-empty K.
    """
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    # Axis 1 is flattened (batch, head); recover both indices.
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # Pointers (head-dim offset added with implicit stride 1)
    q_ptrs = (
        Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    )

    # Initialize running statistics (all in fp32 for numerical stability)
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # running max
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # running sum of exp
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)  # running output (unnormalized)

    # Load Q (once per block); the branch picks the cheapest masking variant
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(
                q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
            )

    # Loop over K, V blocks. For causal attention, keys beyond this query
    # tile's last row can never contribute, so the loop is truncated.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # Load K
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )

        # Compute QK^T * scale (accumulated in fp32)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        qk *= softmax_scale

        # Apply masks: out-of-range keys, then the causal triangle.
        # NOTE(review): the causal test compares the query index offs_m
        # directly with the key index, which aligns positions only when
        # seqlen_q == seqlen_k — confirm callers never pass causal=True
        # with mismatched lengths.
        if not EVEN_N:
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))

        # Online softmax: compute block max
        m_ij = tl.max(qk, 1)  # [BLOCK_M]

        # New running max
        m_new = tl.maximum(m_i, m_ij)  # [BLOCK_M]

        # Rescale factor for previous accumulator
        alpha = tl.exp(m_i - m_new)  # [BLOCK_M]

        # Compute P = exp(qk - m_new)
        p = tl.exp(qk - m_new[:, None])  # [BLOCK_M, BLOCK_N]

        # Sum of current block
        l_ij = tl.sum(p, 1)  # [BLOCK_M]

        # Update running sum: l_new = l_i * alpha + l_ij
        l_new = l_i * alpha + l_ij

        # Rescale previous output and add new contribution
        acc_o = acc_o * alpha[:, None]

        # Load V
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )

        # acc_o += P @ V (P cast to V's dtype so tl.dot can use tensor cores)
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # Update running statistics
        m_i = m_new
        l_i = l_new

    # Final normalization: output = acc_o / l_i
    acc_o = acc_o / l_i[:, None]

    # Compute LSE = m_i + log(l_i)
    lse_i = m_i + tl.log(l_i)

    # Store LSE (laid out per (batch, head) with seqlen_q_rounded pitch)
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    if EVEN_M:
        tl.store(lse_ptrs, lse_i)
    else:
        tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)

    # Store output
    out_ptrs = (
        Out
        + off_b * stride_ob
        + off_h * stride_oh
        + (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )
|
||||||
|
|
||||||
|
def flash_attn_with_lse(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run flash attention and return both the output and the per-row LSE.

    Delegates to the flash_attn library, which supports GQA natively
    (nheads_q a multiple of nheads_kv) without memory overhead.

    Args:
        q: Query tensor [batch, seqlen_q, nheads_q, headdim]
        k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
        v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
        softmax_scale: Scaling factor (default: 1/sqrt(headdim))
        causal: Whether to apply causal masking

    Returns:
        out: Output tensor [batch, seqlen_q, nheads_q, headdim]
        lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
    """
    # Imported lazily so this module can be loaded without flash_attn installed.
    from flash_attn.flash_attn_interface import flash_attn_func

    seqlen_q = q.shape[1]
    headdim = q.shape[3]

    scale = softmax_scale
    if scale is None:
        scale = 1.0 / math.sqrt(headdim)

    # return_attn_probs=True makes flash_attn_func return a triple
    # (out, softmax_lse, S_dmask); the LSE is what we are after here.
    out, lse, _ = flash_attn_func(
        q,
        k,
        v,
        softmax_scale=scale,
        causal=causal,
        return_attn_probs=True,
    )

    # flash_attn pads the LSE to a rounded sequence length; trim the padding
    # so the caller sees exactly [batch, nheads_q, seqlen_q].
    return out, lse[:, :, :seqlen_q]
|
||||||
|
@triton.jit
def _merge_lse_kernel(
    lse1_ptr, lse2_ptr, lse_out_ptr,
    num_elements: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused kernel for merging LSE values.

    Computes, elementwise over flat buffers of ``num_elements`` entries:
        lse_out = max(lse1, lse2) + log(exp(lse1 - max) + exp(lse2 - max))
    which equals log(exp(lse1) + exp(lse2)) in a numerically stable form.

    IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
    bf16 has only 7 bits of mantissa, causing significant errors in exp/log.

    NOTE(review): the flat pointer arithmetic assumes all three buffers are
    contiguous — confirm the launching host code guarantees this.
    """
    # Each program handles BLOCK_SIZE consecutive elements
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    # Load lse values and convert to fp32 for precision
    lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
    lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)

    # Compute max for numerical stability (in fp32)
    max_lse = tl.maximum(lse1, lse2)

    # Compute exp(lse - max_lse) in fp32
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)

    # Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
    lse_merged = max_lse + tl.log(exp1 + exp2)

    # Store result (converted back to the output buffer's dtype by tl.store)
    tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
|
||||||
|
@triton.jit
def _merge_output_kernel(
    o1_ptr, o2_ptr, lse1_ptr, lse2_ptr, o_out_ptr,
    batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused kernel for merging attention outputs.

    For each (batch, seq, head) position, produces the LSE-weighted average
        o_out = (o1 * exp(lse1 - m) + o2 * exp(lse2 - m)) / (exp(lse1 - m) + exp(lse2 - m)),
    with m = max(lse1, lse2) — the online-softmax merge of two partial
    attention results.

    IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
    This is critical for numerical accuracy in chunked attention.

    NOTE(review): indices are computed from shapes alone (no strides passed),
    so o1/o2/o_out must be contiguous [batch, seqlen_q, nheads, headdim] and
    lse1/lse2 contiguous [batch, nheads, seqlen_q] — confirm at call sites.
    """
    # Grid: one program per (batch, seqlen_q, nheads) position; the head
    # dimension is processed BLOCK_SIZE elements at a time below.
    pid_batch = tl.program_id(0)
    pid_seq = tl.program_id(1)
    pid_head = tl.program_id(2)

    # Flat LSE index into a [batch, nheads, seqlen_q] layout
    lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq

    # Load LSE values and convert to fp32 for precision
    lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
    lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)

    # Compute max and scaling factors in fp32
    max_lse = tl.maximum(lse1, lse2)
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)
    sum_exp = exp1 + exp2

    # Process headdim in chunks of BLOCK_SIZE
    for d_offset in range(0, headdim, BLOCK_SIZE):
        d_idx = d_offset + tl.arange(0, BLOCK_SIZE)
        mask = d_idx < headdim

        # Flat output index into a [batch, seqlen_q, nheads, headdim] layout
        # (base_idx is loop-invariant; recomputed each iteration as written)
        base_idx = (pid_batch * seqlen_q * nheads * headdim +
                    pid_seq * nheads * headdim +
                    pid_head * headdim)
        o_idx = base_idx + d_idx

        # Load o1, o2 and convert to fp32 for weighted sum
        o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
        o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)

        # Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
        o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp

        # Store result (Triton will convert back to original dtype)
        tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
|
|
||||||
|
def merge_attention_outputs(
    o1: torch.Tensor,
    lse1: torch.Tensor,
    o2: torch.Tensor,
    lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge two attention outputs using online softmax (Triton fused kernels).

    This implements the online softmax merging formula:
    - m_new = max(lse1, lse2)
    - o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2) / (exp(lse1 - m_new) + exp(lse2 - m_new))
    - lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new))

    Two kernels are launched: one merging the LSE buffers elementwise, one
    merging the outputs weighted by the *input* LSEs.

    Args:
        o1: First output [batch, seqlen_q, nheads, headdim]
        lse1: First LSE [batch, nheads, seqlen_q]
        o2: Second output [batch, seqlen_q, nheads, headdim]
        lse2: Second LSE [batch, nheads, seqlen_q]

    Returns:
        o_merged: Merged output [batch, seqlen_q, nheads, headdim]
        lse_merged: Merged LSE [batch, nheads, seqlen_q]

    NOTE(review): both kernels index with flat offsets computed from shapes,
    so every input must be contiguous in the documented layout; no
    .contiguous() is enforced here — verify callers never pass strided views.
    """
    batch, seqlen_q, nheads, headdim = o1.shape

    # Allocate output tensors (same dtype/device as the inputs)
    o_merged = torch.empty_like(o1)
    lse_merged = torch.empty_like(lse1)

    # Launch LSE merge kernel: 1D grid over all (batch, head, seq) entries
    num_lse_elements = batch * nheads * seqlen_q
    BLOCK_SIZE_LSE = 256
    grid_lse = (triton.cdiv(num_lse_elements, BLOCK_SIZE_LSE),)
    _merge_lse_kernel[grid_lse](
        lse1, lse2, lse_merged,
        num_lse_elements,
        BLOCK_SIZE=BLOCK_SIZE_LSE,
    )

    # Launch output merge kernel: one program per (batch, seq, head) position
    BLOCK_SIZE = 128
    grid_output = (batch, seqlen_q, nheads)
    _merge_output_kernel[grid_output](
        o1, o2, lse1, lse2, o_merged,
        batch, seqlen_q, nheads, headdim,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return o_merged, lse_merged
|
||||||
|
def chunked_attention_varlen(
    q: torch.Tensor,
    kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k_list: List[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k_list: List[int],
    softmax_scale: Optional[float] = None,
    causal_mask_per_chunk: Optional[List[bool]] = None,
) -> torch.Tensor:
    """
    Compute attention with KV split across multiple chunks.

    This is the core function for chunked prefill. It computes attention
    against each KV chunk and merges results using online softmax.

    For causal attention with chunked KV:
    - Last chunk (current tokens): apply causal mask
    - Previous chunks: no causal mask (all previous tokens are valid context)

    Args:
        q: Query tensor [total_q_tokens, nheads, headdim]
        kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim]
        cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
        cu_seqlens_k_list: List of cumulative sequence lengths for each KV
            chunk (currently unused; reserved for true varlen support)
        max_seqlen_q: Maximum query sequence length (currently unused)
        max_seqlen_k_list: List of maximum key sequence lengths (currently unused)
        softmax_scale: Scaling factor (default: 1/sqrt(headdim))
        causal_mask_per_chunk: Whether to apply causal mask for each chunk;
            defaults to causal on the last chunk only

    Returns:
        out: Output tensor [total_q_tokens, nheads, headdim]

    Raises:
        ValueError: If kv_chunks is empty, or causal_mask_per_chunk does not
            have one entry per chunk.
        NotImplementedError: If cu_seqlens_q describes more than one sequence
            (the current implementation flattens Q into a single batch entry).
    """
    if len(kv_chunks) == 0:
        raise ValueError("Need at least one KV chunk")

    headdim = q.shape[2]
    batch = cu_seqlens_q.shape[0] - 1

    # Q is flattened into a single batch entry below, so multi-sequence
    # (varlen) input is not supported yet. Previously this was silently
    # ignored, computing cross-sequence attention; fail loudly instead.
    if batch != 1:
        raise NotImplementedError(
            f"chunked_attention_varlen currently supports a single sequence "
            f"(batch=1), got batch={batch}"
        )

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    if causal_mask_per_chunk is None:
        # Default: causal for the last chunk only
        causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True]
    elif len(causal_mask_per_chunk) != len(kv_chunks):
        raise ValueError(
            f"causal_mask_per_chunk has {len(causal_mask_per_chunk)} entries "
            f"for {len(kv_chunks)} KV chunks"
        )

    # Q is identical for every chunk; add the batch dimension once,
    # outside the loop.
    q_batched = q.unsqueeze(0)  # [1, total_q, nheads, headdim]

    # Accumulated output and LSE across chunks
    accumulated_o = None
    accumulated_lse = None

    for is_causal, (k_chunk, v_chunk) in zip(causal_mask_per_chunk, kv_chunks):
        # Attention of all queries against this KV chunk
        chunk_o, chunk_lse = flash_attn_with_lse(
            q_batched,
            k_chunk,
            v_chunk,
            softmax_scale=softmax_scale,
            causal=is_causal,
        )

        # Fold this chunk into the running result via online-softmax merging
        if accumulated_o is None:
            accumulated_o = chunk_o
            accumulated_lse = chunk_lse
        else:
            accumulated_o, accumulated_lse = merge_attention_outputs(
                accumulated_o, accumulated_lse,
                chunk_o, chunk_lse,
            )

    # Remove the batch dimension added above
    return accumulated_o.squeeze(0)
|
||||||
|
|
||||||
|
class ChunkedPrefillState:
    """
    Accumulator for attention results across chunked-prefill steps.

    Keeps, for every transformer layer, the running merged attention output
    together with its log-sum-exp, so successive KV chunks can be folded in
    with online-softmax merging.
    """

    def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device):
        self.num_layers = num_layers
        self.dtype = dtype
        self.device = device

        # One (accumulated_output, accumulated_lse) pair per layer;
        # None until the first chunk for that layer arrives.
        self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = (
            [None] * num_layers
        )

        # Count of chunks processed so far (reset by clear()).
        self.processed_chunks: int = 0

    def update_layer(
        self,
        layer_id: int,
        chunk_output: torch.Tensor,
        chunk_lse: torch.Tensor,
    ):
        """Fold a new chunk's (output, lse) into the running state of one layer."""
        previous = self.layer_states[layer_id]
        if previous is None:
            # First chunk for this layer: nothing to merge with yet.
            self.layer_states[layer_id] = (chunk_output, chunk_lse)
            return
        prev_o, prev_lse = previous
        self.layer_states[layer_id] = merge_attention_outputs(
            prev_o, prev_lse,
            chunk_output, chunk_lse,
        )

    def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]:
        """Return the merged output for a layer, or None if nothing accumulated."""
        state = self.layer_states[layer_id]
        return None if state is None else state[0]

    def clear(self):
        """Drop all accumulated state and reset the chunk counter."""
        self.layer_states = [None] * self.num_layers
        self.processed_chunks = 0
|
|
||||||
|
|
||||||
|
# Test function
|
||||||
|
def _test_chunked_attention():
    """Test chunked attention using flash_attn_with_lse and merge_attention_outputs.

    Requires a CUDA device and the flash_attn package. For each dtype /
    chunk-count / shape combination: compute the full-attention reference,
    then recompute it by splitting K,V into equal chunks and merging the
    per-chunk results, and report the max/mean absolute difference.

    All seqlen values used here (1024..8192) are divisible by every
    num_chunks value (64/128/256), so chunk_size * num_chunks == seqlen and
    no tokens are dropped by the integer division below.
    """
    from flash_attn.flash_attn_interface import flash_attn_func

    # Fixed seed so diffs are comparable run-to-run
    torch.manual_seed(42)

    print("=" * 70)
    print("Test: Chunked attention vs flash_attn_func (non-causal)")
    print("=" * 70)
    print("Splitting K,V into chunks, computing attention per chunk, then merging")
    print()

    for dtype in [torch.float16, torch.bfloat16]:
        for num_chunks in [64, 128, 256]:
            for batch, seqlen, nheads, headdim in [
                (1, 1024, 32, 128),
                (1, 2048, 32, 128),
                (1, 4096, 32, 128),
                (1, 8192, 32, 128),
            ]:
                # Generate random Q, K, V
                q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
                k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
                v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)

                # Reference: full attention (non-causal)
                out_ref = flash_attn_func(q, k, v, causal=False)

                # Chunked attention: split K, V into equal chunks along seqlen
                chunk_size = seqlen // num_chunks
                accumulated_o = None
                accumulated_lse = None

                for i in range(num_chunks):
                    start = i * chunk_size
                    end = (i + 1) * chunk_size

                    k_chunk = k[:, start:end, :, :]
                    v_chunk = v[:, start:end, :, :]

                    # Q attends to this K,V chunk (non-causal)
                    chunk_o, chunk_lse = flash_attn_with_lse(
                        q, k_chunk, v_chunk, causal=False
                    )

                    if accumulated_o is None:
                        accumulated_o = chunk_o
                        accumulated_lse = chunk_lse
                    else:
                        # Merge with previous chunks via online softmax
                        accumulated_o, accumulated_lse = merge_attention_outputs(
                            accumulated_o, accumulated_lse,
                            chunk_o, chunk_lse
                        )

                # Compare chunked result against the full-attention reference
                out_diff = (out_ref - accumulated_o).abs()
                out_max_diff = out_diff.max().item()
                out_mean_diff = out_diff.mean().item()

                # 1e-2 tolerance accommodates fp16/bf16 accumulation error
                status = "PASS" if out_max_diff < 1e-2 else "FAIL"
                print(
                    f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} "
                    f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) "
                    f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}"
                )

    print()
    print("=" * 70)
    print("Test completed!")
|
||||||
|
# Allow running this module directly as a quick smoke test
# (requires CUDA and the flash_attn package).
if __name__ == "__main__":
    _test_chunked_attention()
1167
nanovllm/ops/xattn.py
Normal file
1167
nanovllm/ops/xattn.py
Normal file
File diff suppressed because it is too large
Load Diff
244
tests/test_xattn_estimate_chunked.py
Normal file
244
tests/test_xattn_estimate_chunked.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
"""
|
||||||
|
Test: Compare xattn_estimate vs xattn_estimate_chunked
|
||||||
|
|
||||||
|
Verify that chunked estimation with EXTERNAL chunking produces the same mask
|
||||||
|
as standard estimation. This ensures the chunked version can be used in
|
||||||
|
chunked prefill scenarios without accuracy loss.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||||
|
python tests/test_xattn_estimate_chunked.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
import torch
|
||||||
|
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Configuration
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# Configuration for xattn_estimate_chunked consistency test.
|
||||||
|
# Key requirements for 100% match:
|
||||||
|
# 1. Use matching chunk_size for both standard and chunked versions
|
||||||
|
# 2. Use same random seed for reproducibility
|
||||||
|
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
|
||||||
|
# floating point precision in cumulative sum calculations.
|
||||||
|
BLOCK_SIZE = 64
|
||||||
|
STRIDE = 4
|
||||||
|
THRESHOLD = 0.9
|
||||||
|
CHUNK_SIZE = 4096 # External chunking size
|
||||||
|
|
||||||
|
# Test sequence lengths
|
||||||
|
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Utility Functions
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
    """Report how closely two boolean masks agree.

    Prints the match rate (and the first few mismatch positions, if any).
    Returns True only when the masks are identical in shape and content.
    """
    if mask1.shape != mask2.shape:
        print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
        return False

    mismatch = mask1 != mask2
    diff = mismatch.sum().item()
    total = mask1.numel()
    match_rate = (total - diff) / total * 100

    print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")

    if diff > 0:
        # Show a handful of disagreeing coordinates to aid debugging
        diff_indices = torch.where(mismatch)
        first_positions = list(zip(*[idx[:5].tolist() for idx in diff_indices]))
        print(f" First 5 diff positions: {first_positions}")

    return diff == 0
|
||||||
|
|
||||||
|
def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
    """
    Run xattn_estimate_chunked with EXTERNAL chunking.
    This simulates how chunked prefill should be used in practice.

    Args:
        query: [batch, num_heads, q_len, head_dim]
        key: [batch, num_heads, k_len, head_dim]
        block_size: attention block size used for mask granularity
        stride: antidiagonal stride passed through to the estimator
        threshold: cumulative-attention threshold for block selection
        chunk_size: number of query tokens per external chunk

    Returns:
        (combined_attn_sum, combined_mask), each shaped
        [batch, num_heads, q_block_num, k_block_num].

    NOTE(review): assumes attn_sum_chunk and mask_chunk share block-grid
    shape [batch, heads, chunk_q_blocks, chunk_k_blocks] — confirm against
    xattn_estimate_chunked in xattn.py.
    """
    batch_size, num_heads, q_len, head_dim = query.shape
    _, _, k_len, _ = key.shape

    # Number of blocks in the full (un-chunked) block grid, rounding up
    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size

    # If Q fits in one chunk, call directly (no stitching needed)
    if q_len <= chunk_size:
        return xattn_estimate_chunked(
            query, key,
            q_start_pos=0,
            block_size=block_size,
            stride=stride,
            threshold=threshold,
            use_triton=True,
            chunk_size=chunk_size,
        )

    # External chunking: split Q and call the estimator once per chunk
    num_q_chunks = (q_len + chunk_size - 1) // chunk_size
    print(f" External chunking: {num_q_chunks} chunks")

    # Full-size accumulators; chunk results are pasted into row bands below.
    combined_attn_sum = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=query.dtype, device=query.device
    )
    combined_mask = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=torch.bool, device=query.device
    )

    # Running offset (in blocks) of the next chunk's first query block row
    q_block_offset = 0
    for q_chunk_idx in range(num_q_chunks):
        q_chunk_start = q_chunk_idx * chunk_size
        q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)

        q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]

        # For causal attention, K accumulates up to current Q position
        # q_start_pos=0 means Q starts at position 0 in the full sequence
        # K is [0, q_chunk_end) for causal attention
        k_end = q_chunk_end
        k_chunk = key[:, :, :k_end, :]

        attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
            q_chunk, k_chunk,
            q_start_pos=q_chunk_start,
            block_size=block_size,
            stride=stride,
            threshold=threshold,
            use_triton=True,
            chunk_size=chunk_size,
        )

        # Place chunk results into combined output: a band of query-block
        # rows, and only the key-block columns this chunk could see (the
        # remaining columns stay zero/False, matching causal masking).
        chunk_q_blocks = mask_chunk.shape[2]
        chunk_k_blocks = mask_chunk.shape[3]
        combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
        combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
        q_block_offset += chunk_q_blocks

    return combined_attn_sum, combined_mask
|
||||||
|
|
||||||
|
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
    """Compare standard vs externally-chunked xattn_estimate for one sequence length.

    Generates random bf16 Q/K on the GPU, runs both estimation paths with the
    module-level config (BLOCK_SIZE/STRIDE/THRESHOLD/CHUNK_SIZE), and compares
    the resulting block masks on their overlapping region.

    Returns:
        bool: True when the two masks agree (per compare_masks), False on any
        mismatch or if either path raises.
    """
    print(f"\nTesting seq_len={seq_len}")
    print("=" * 60)

    # Random inputs; batch of 1, bf16 to match the Triton kernels' dtype.
    qk_shape = (1, num_heads, seq_len, head_dim)
    query = torch.randn(*qk_shape, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(*qk_shape, device="cuda", dtype=torch.bfloat16)

    # --- Reference path: one-shot (non-chunked) estimation ---
    print("[1] Running standard xattn_estimate...")
    try:
        attn_sum_std, mask_std = xattn_estimate(
            query,
            key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            use_triton=True,
            causal=True,
        )
        density_std = mask_std.float().mean().item()
        print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
    except Exception as e:
        print(f" ERROR: {e}")
        traceback.print_exc()
        return False

    # --- Candidate path: Q split into chunks by the external driver ---
    print("[2] Running chunked xattn_estimate (external chunking)...")
    try:
        attn_sum_chunked, mask_chunked = run_chunked_externally(
            query,
            key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
        )
        density_chunked = mask_chunked.float().mean().item()
        print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
    except Exception as e:
        print(f" ERROR: {e}")
        traceback.print_exc()
        return False

    # --- Compare on the block region both paths produced ---
    print("[3] Comparing results...")
    q_blocks = mask_chunked.shape[2]
    k_blocks = mask_chunked.shape[3]

    # The standard mask may cover more blocks; restrict it to the chunked extent.
    masks_match = compare_masks(
        mask_std[:, :, :q_blocks, :k_blocks], mask_chunked, "standard", "chunked"
    )

    # Attention-sum comparison is informational only; it does not gate pass/fail.
    attn_sum_std_comparable = attn_sum_std[:, :, :q_blocks, :k_blocks]
    if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
        attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
        print(f" Attn sum max diff: {attn_diff:.6f}")
    else:
        print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")

    # Release the large tensors before the next sequence length is tested.
    del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
    torch.cuda.empty_cache()

    return masks_match
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Main Test
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Banner and configuration echo.
    print("XAttention Chunked vs Standard Test")
    print("=" * 60)
    print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
    print(f"External chunk_size={CHUNK_SIZE}")
    print()

    # The estimation kernels are CUDA-only; bail out early on CPU-only hosts.
    if not torch.cuda.is_available():
        print("CUDA not available!")
        sys.exit(1)

    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print("✓ xattn_estimate imported")
    print("✓ xattn_estimate_chunked imported")

    # Exercise every configured sequence length, recording (seq_len, #chunks, ok).
    results = []
    for length in TEST_SEQ_LENS:
        ok = test_single_seq_len(length)
        n_chunks = (length + CHUNK_SIZE - 1) // CHUNK_SIZE  # ceil division
        results.append((length, n_chunks, ok))
    all_passed = all(ok for _, _, ok in results)

    # Per-length summary table.
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for length, n_chunks, ok in results:
        status = "PASSED" if ok else "FAILED"
        print(f" seq_len={length:5d} ({n_chunks} chunk{'s' if n_chunks > 1 else ' '}): {status}")

    print("=" * 60)
    if all_passed:
        print("ALL TESTS PASSED!")
        sys.exit(0)
    print("SOME TESTS FAILED!")
    sys.exit(1)
|
||||||
Reference in New Issue
Block a user