✨ feat: add nanovllm.ops module with XAttention estimation kernels
Add ops module ported from the tzj/minference branch, containing:

- xattn.py: XAttention block-importance estimation with Triton kernels
  - xattn_estimate(): standard estimation for the sparse attention mask
  - xattn_estimate_chunked(): chunked-prefill-compatible version
  - flat_group_gemm_fuse_reshape(): fused strided-reshape + GEMM kernel
  - softmax_fuse_block_sum(): online softmax + block-wise sum kernel
- chunked_attention.py: flash attention with LSE output for chunk merging
- test_xattn_estimate_chunked.py: verification test (all seq_lens pass)

This prepares the foundation for the AttentionPolicy refactoring, where
XAttentionPolicy.estimate() will call these ops.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
nanovllm/ops/__init__.py (new file, 38 lines)
@@ -0,0 +1,38 @@
"""
Operators module for nano-vLLM.

This module contains low-level attention operators and kernels.
"""

from nanovllm.ops.chunked_attention import (
    flash_attn_with_lse,
    merge_attention_outputs,
    chunked_attention_varlen,
    ChunkedPrefillState,
)

from nanovllm.ops.xattn import (
    xattn_estimate,
    xattn_estimate_chunked,
    flat_group_gemm_fuse_reshape,
    softmax_fuse_block_sum,
    find_blocks_chunked,
    create_causal_mask,
    compute_sparsity,
)

__all__ = [
    # chunked_attention
    "flash_attn_with_lse",
    "merge_attention_outputs",
    "chunked_attention_varlen",
    "ChunkedPrefillState",
    # xattn
    "xattn_estimate",
    "xattn_estimate_chunked",
    "flat_group_gemm_fuse_reshape",
    "softmax_fuse_block_sum",
    "find_blocks_chunked",
    "create_causal_mask",
    "compute_sparsity",
]
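For orientation, a minimal usage sketch of the re-exported chunk-merging ops (illustrative only, not part of the committed files; placeholder shapes, and it assumes a CUDA device with the flash_attn package installed):

import torch
from nanovllm.ops import flash_attn_with_lse, merge_attention_outputs

q = torch.randn(1, 1024, 32, 128, device="cuda", dtype=torch.float16)   # [B, Sq, H, D]
k1 = torch.randn(1, 2048, 32, 128, device="cuda", dtype=torch.float16)  # first KV chunk
v1 = torch.randn_like(k1)
k2 = torch.randn_like(k1)                                                # second KV chunk
v2 = torch.randn_like(v1)

# Attend to each KV chunk separately, then merge with online softmax.
o1, lse1 = flash_attn_with_lse(q, k1, v1, causal=False)
o2, lse2 = flash_attn_with_lse(q, k2, v2, causal=False)
o, lse = merge_attention_outputs(o1, lse1, o2, lse2)  # matches attending over [k1; k2]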
nanovllm/ops/chunked_attention.py (new file, 624 lines)
@@ -0,0 +1,624 @@
"""
Chunked attention implementation for CPU KV cache offloading.

This module implements flash attention with LSE (log-sum-exp) output,
enabling proper online softmax merging for chunked prefill.

Key functions:
- flash_attn_with_lse: Flash attention that returns output and LSE
- merge_attention_outputs: Merge outputs from multiple KV chunks
- chunked_attention_varlen: High-level interface for chunked attention
"""

import math
import torch
import triton
import triton.language as tl
from typing import Tuple, List, Optional


@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel_with_lse(
    Q,
    K,
    V,
    Out,
    Lse,
    softmax_scale,
    stride_qb,
    stride_qh,
    stride_qm,
    stride_kb,
    stride_kh,
    stride_kn,
    stride_vb,
    stride_vh,
    stride_vn,
    stride_ob,
    stride_oh,
    stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    headdim,
    CACHE_KEY_SEQLEN_Q,
    CACHE_KEY_SEQLEN_K,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Flash attention forward kernel with LSE output.

    Implements the standard Flash Attention online softmax algorithm:
    - m_i: running max of attention scores
    - l_i: running sum of exp(scores - m_i)
    - acc_o: running sum of softmax(scores) @ V (unnormalized)

    Final output: acc_o / l_i
    Final LSE: m_i + log(l_i)
    """
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # Pointers
    q_ptrs = (
        Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
    )

    # Initialize running statistics
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # running max
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # running sum of exp
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)  # running output (unnormalized)

    # Load Q (once per block)
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(
                q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
            )

    # Loop over K, V blocks
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # Load K
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )

        # Compute QK^T * scale
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        qk *= softmax_scale

        # Apply masks
        if not EVEN_N:
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))

        # Online softmax: compute block max
        m_ij = tl.max(qk, 1)  # [BLOCK_M]

        # New running max
        m_new = tl.maximum(m_i, m_ij)  # [BLOCK_M]

        # Rescale factor for previous accumulator
        alpha = tl.exp(m_i - m_new)  # [BLOCK_M]

        # Compute P = exp(qk - m_new)
        p = tl.exp(qk - m_new[:, None])  # [BLOCK_M, BLOCK_N]

        # Sum of current block
        l_ij = tl.sum(p, 1)  # [BLOCK_M]

        # Update running sum: l_new = l_i * alpha + l_ij
        l_new = l_i * alpha + l_ij

        # Rescale previous output and add new contribution
        acc_o = acc_o * alpha[:, None]

        # Load V
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )

        # acc_o += P @ V
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # Update running statistics
        m_i = m_new
        l_i = l_new

    # Final normalization: output = acc_o / l_i
    acc_o = acc_o / l_i[:, None]

    # Compute LSE = m_i + log(l_i)
    lse_i = m_i + tl.log(l_i)

    # Store LSE
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    if EVEN_M:
        tl.store(lse_ptrs, lse_i)
    else:
        tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)

    # Store output
    out_ptrs = (
        Out
        + off_b * stride_ob
        + off_h * stride_oh
        + (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )

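The kernel above follows the standard Flash Attention online-softmax recurrence. As a plain-PyTorch reference of the same bookkeeping (a sketch for a single head in fp32, illustrative only and not part of this file):

def online_softmax_reference(q, k, v, block_n=128):
    # q: [M, D], k/v: [N, D]; mirrors the m_i / l_i / acc_o updates in the kernel.
    import math, torch
    scale = 1.0 / math.sqrt(q.shape[-1])
    m_i = torch.full((q.shape[0],), float("-inf"))
    l_i = torch.zeros(q.shape[0])
    acc_o = torch.zeros(q.shape[0], v.shape[-1])
    for s in range(0, k.shape[0], block_n):
        qk = (q @ k[s:s + block_n].T) * scale
        m_new = torch.maximum(m_i, qk.max(dim=1).values)
        alpha = torch.exp(m_i - m_new)            # rescale factor for the old accumulator
        p = torch.exp(qk - m_new[:, None])
        l_i = l_i * alpha + p.sum(dim=1)
        acc_o = acc_o * alpha[:, None] + p @ v[s:s + block_n]
        m_i = m_new
    return acc_o / l_i[:, None], m_i + torch.log(l_i)   # (output, LSE)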
def flash_attn_with_lse(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flash attention forward pass that returns both output and LSE.

    Uses the flash_attn library, which natively supports GQA without memory overhead.

    Args:
        q: Query tensor [batch, seqlen_q, nheads_q, headdim]
        k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
        v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
        softmax_scale: Scaling factor (default: 1/sqrt(headdim))
        causal: Whether to apply causal masking

    Returns:
        out: Output tensor [batch, seqlen_q, nheads_q, headdim]
        lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
    """
    from flash_attn.flash_attn_interface import flash_attn_func

    batch, seqlen_q, nheads_q, headdim = q.shape
    _, seqlen_k, nheads_kv, _ = k.shape

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    # Use flash_attn_func, which natively supports GQA (no memory overhead).
    # With return_attn_probs=True it returns (out, softmax_lse, S_dmask),
    # which is how we obtain the LSE alongside the output.
    out, lse, _ = flash_attn_func(
        q, k, v,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=True,  # makes it return (out, softmax_lse, S_dmask)
    )

    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
    # Trim to actual seqlen_q
    lse = lse[:, :, :seqlen_q]

    return out, lse

@triton.jit
def _merge_lse_kernel(
    lse1_ptr, lse2_ptr, lse_out_ptr,
    num_elements: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused kernel for merging LSE values.

    IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
    bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
    """
    # Each program handles BLOCK_SIZE elements
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    # Load LSE values and convert to fp32 for precision
    lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
    lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)

    # Compute max for numerical stability (in fp32)
    max_lse = tl.maximum(lse1, lse2)

    # Compute exp(lse - max_lse) in fp32
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)

    # Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
    lse_merged = max_lse + tl.log(exp1 + exp2)

    # Store result (converted back to the original dtype)
    tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)


@triton.jit
def _merge_output_kernel(
    o1_ptr, o2_ptr, lse1_ptr, lse2_ptr, o_out_ptr,
    batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused kernel for merging attention outputs.

    IMPORTANT: Uses fp32 for exp operations and the weighted sum to avoid precision loss.
    This is critical for numerical accuracy in chunked attention.
    """
    # Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
    pid_batch = tl.program_id(0)
    pid_seq = tl.program_id(1)
    pid_head = tl.program_id(2)

    # Compute LSE index: [batch, nheads, seqlen_q]
    lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq

    # Load LSE values and convert to fp32 for precision
    lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
    lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)

    # Compute max and scaling factors in fp32
    max_lse = tl.maximum(lse1, lse2)
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)
    sum_exp = exp1 + exp2

    # Process headdim in chunks
    for d_offset in range(0, headdim, BLOCK_SIZE):
        d_idx = d_offset + tl.arange(0, BLOCK_SIZE)
        mask = d_idx < headdim

        # Compute output index: [batch, seqlen_q, nheads, headdim]
        base_idx = (pid_batch * seqlen_q * nheads * headdim +
                    pid_seq * nheads * headdim +
                    pid_head * headdim)
        o_idx = base_idx + d_idx

        # Load o1, o2 and convert to fp32 for the weighted sum
        o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
        o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)

        # Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
        o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp

        # Store result (Triton converts back to the original dtype)
        tl.store(o_out_ptr + o_idx, o_merged, mask=mask)

def merge_attention_outputs(
    o1: torch.Tensor,
    lse1: torch.Tensor,
    o2: torch.Tensor,
    lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge two attention outputs using online softmax (Triton fused kernels).

    This implements the online softmax merging formula:
    - m_new = max(lse1, lse2)
    - o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2) / (exp(lse1 - m_new) + exp(lse2 - m_new))
    - lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new))

    Args:
        o1: First output [batch, seqlen_q, nheads, headdim]
        lse1: First LSE [batch, nheads, seqlen_q]
        o2: Second output [batch, seqlen_q, nheads, headdim]
        lse2: Second LSE [batch, nheads, seqlen_q]

    Returns:
        o_merged: Merged output [batch, seqlen_q, nheads, headdim]
        lse_merged: Merged LSE [batch, nheads, seqlen_q]
    """
    batch, seqlen_q, nheads, headdim = o1.shape

    # Allocate output tensors
    o_merged = torch.empty_like(o1)
    lse_merged = torch.empty_like(lse1)

    # Launch LSE merge kernel
    num_lse_elements = batch * nheads * seqlen_q
    BLOCK_SIZE_LSE = 256
    grid_lse = (triton.cdiv(num_lse_elements, BLOCK_SIZE_LSE),)
    _merge_lse_kernel[grid_lse](
        lse1, lse2, lse_merged,
        num_lse_elements,
        BLOCK_SIZE=BLOCK_SIZE_LSE,
    )

    # Launch output merge kernel
    BLOCK_SIZE = 128
    grid_output = (batch, seqlen_q, nheads)
    _merge_output_kernel[grid_output](
        o1, o2, lse1, lse2, o_merged,
        batch, seqlen_q, nheads, headdim,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return o_merged, lse_merged

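The merge above can be sanity-checked against the closed-form identity in the docstring. A minimal eager-mode PyTorch equivalent (illustrative, not part of this file; broadcasting assumes the documented shapes):

def merge_reference(o1, lse1, o2, lse2):
    # o*: [B, Sq, H, D], lse*: [B, H, Sq]
    import torch
    m = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - m)
    w2 = torch.exp(lse2 - m)
    # Reorder weights to [B, Sq, H, 1] so they broadcast over headdim.
    w1_o = w1.permute(0, 2, 1).unsqueeze(-1)
    w2_o = w2.permute(0, 2, 1).unsqueeze(-1)
    o = (o1 * w1_o + o2 * w2_o) / (w1_o + w2_o)
    lse = m + torch.log(w1 + w2)
    return o, lse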
def chunked_attention_varlen(
    q: torch.Tensor,
    kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k_list: List[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k_list: List[int],
    softmax_scale: Optional[float] = None,
    causal_mask_per_chunk: Optional[List[bool]] = None,
) -> torch.Tensor:
    """
    Compute attention with KV split across multiple chunks.

    This is the core function for chunked prefill. It computes attention
    against each KV chunk and merges the results using online softmax.

    For causal attention with chunked KV:
    - Chunk holding the current query tokens (the last chunk by default): apply the causal mask
    - Previous chunks: no causal mask (all previous tokens are valid context)

    Args:
        q: Query tensor [total_q_tokens, nheads, headdim]
        kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim]
        cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
        cu_seqlens_k_list: List of cumulative sequence lengths for each KV chunk
        max_seqlen_q: Maximum query sequence length
        max_seqlen_k_list: List of maximum key sequence lengths for each chunk
        softmax_scale: Scaling factor
        causal_mask_per_chunk: Whether to apply a causal mask for each chunk

    Returns:
        out: Output tensor [total_q_tokens, nheads, headdim]
    """
    if len(kv_chunks) == 0:
        raise ValueError("Need at least one KV chunk")

    nheads = q.shape[1]
    headdim = q.shape[2]
    batch = cu_seqlens_q.shape[0] - 1

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    if causal_mask_per_chunk is None:
        # Default: causal for the last chunk only
        causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True]

    # Initialize accumulated output and LSE
    accumulated_o = None
    accumulated_lse = None

    for chunk_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
        is_causal = causal_mask_per_chunk[chunk_idx]

        # Reshape Q for batch processing.
        # For varlen, each sequence would need to be handled separately;
        # for simplicity, assume a single sequence (batch=1) for now.
        q_batched = q.unsqueeze(0)  # [1, total_q, nheads, headdim]

        # Compute attention for this chunk
        chunk_o, chunk_lse = flash_attn_with_lse(
            q_batched,
            k_chunk,
            v_chunk,
            softmax_scale=softmax_scale,
            causal=is_causal,
        )

        # Merge with the accumulated result
        if accumulated_o is None:
            accumulated_o = chunk_o
            accumulated_lse = chunk_lse
        else:
            accumulated_o, accumulated_lse = merge_attention_outputs(
                accumulated_o, accumulated_lse,
                chunk_o, chunk_lse,
            )

    # Remove batch dimension
    return accumulated_o.squeeze(0)

class ChunkedPrefillState:
    """
    State for tracking chunked prefill progress.

    This class maintains the accumulated attention output and LSE
    across multiple prefill chunks.
    """

    def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device):
        self.num_layers = num_layers
        self.dtype = dtype
        self.device = device

        # Per-layer accumulated outputs
        # Each entry: (accumulated_output, accumulated_lse) or None
        self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
            None for _ in range(num_layers)
        ]

        # Track which chunks have been processed
        self.processed_chunks: int = 0

    def update_layer(
        self,
        layer_id: int,
        chunk_output: torch.Tensor,
        chunk_lse: torch.Tensor,
    ):
        """Update accumulated state for a layer with a new chunk's output."""
        if self.layer_states[layer_id] is None:
            self.layer_states[layer_id] = (chunk_output, chunk_lse)
        else:
            acc_o, acc_lse = self.layer_states[layer_id]
            merged_o, merged_lse = merge_attention_outputs(
                acc_o, acc_lse,
                chunk_output, chunk_lse,
            )
            self.layer_states[layer_id] = (merged_o, merged_lse)

    def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]:
        """Get the final accumulated output for a layer."""
        if self.layer_states[layer_id] is None:
            return None
        return self.layer_states[layer_id][0]

    def clear(self):
        """Clear all accumulated state."""
        self.layer_states = [None for _ in range(self.num_layers)]
        self.processed_chunks = 0

# Test function
def _test_chunked_attention():
    """Test chunked attention using flash_attn_with_lse and merge_attention_outputs."""
    from flash_attn.flash_attn_interface import flash_attn_func

    torch.manual_seed(42)

    print("=" * 70)
    print("Test: Chunked attention vs flash_attn_func (non-causal)")
    print("=" * 70)
    print("Splitting K,V into chunks, computing attention per chunk, then merging")
    print()

    for dtype in [torch.float16, torch.bfloat16]:
        for num_chunks in [64, 128, 256]:
            for batch, seqlen, nheads, headdim in [
                (1, 1024, 32, 128),
                (1, 2048, 32, 128),
                (1, 4096, 32, 128),
                (1, 8192, 32, 128),
            ]:
                # Generate random Q, K, V
                q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
                k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
                v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)

                # Reference: full attention (non-causal)
                out_ref = flash_attn_func(q, k, v, causal=False)

                # Chunked attention: split K, V into chunks
                chunk_size = seqlen // num_chunks
                accumulated_o = None
                accumulated_lse = None

                for i in range(num_chunks):
                    start = i * chunk_size
                    end = (i + 1) * chunk_size

                    k_chunk = k[:, start:end, :, :]
                    v_chunk = v[:, start:end, :, :]

                    # Q attends to this K,V chunk (non-causal)
                    chunk_o, chunk_lse = flash_attn_with_lse(
                        q, k_chunk, v_chunk, causal=False
                    )

                    if accumulated_o is None:
                        accumulated_o = chunk_o
                        accumulated_lse = chunk_lse
                    else:
                        # Merge with previous chunks
                        accumulated_o, accumulated_lse = merge_attention_outputs(
                            accumulated_o, accumulated_lse,
                            chunk_o, chunk_lse
                        )

                # Compare
                out_diff = (out_ref - accumulated_o).abs()
                out_max_diff = out_diff.max().item()
                out_mean_diff = out_diff.mean().item()

                status = "PASS" if out_max_diff < 1e-2 else "FAIL"
                print(
                    f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} "
                    f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) "
                    f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}"
                )

    print()
    print("=" * 70)
    print("Test completed!")


if __name__ == "__main__":
    _test_chunked_attention()
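As a usage sketch for chunked_attention_varlen above (single sequence with one previous KV chunk and one current chunk; placeholder sizes, illustrative only and not part of the diff):

total_q, nheads, headdim = 2048, 32, 128
q = torch.randn(total_q, nheads, headdim, device="cuda", dtype=torch.float16)
k_prev = torch.randn(1, 4096, nheads, headdim, device="cuda", dtype=torch.float16)
v_prev = torch.randn_like(k_prev)
k_cur = torch.randn(1, total_q, nheads, headdim, device="cuda", dtype=torch.float16)
v_cur = torch.randn_like(k_cur)

out = chunked_attention_varlen(
    q,
    kv_chunks=[(k_prev, v_prev), (k_cur, v_cur)],
    cu_seqlens_q=torch.tensor([0, total_q], device="cuda", dtype=torch.int32),
    cu_seqlens_k_list=[torch.tensor([0, 4096], device="cuda", dtype=torch.int32),
                       torch.tensor([0, total_q], device="cuda", dtype=torch.int32)],
    max_seqlen_q=total_q,
    max_seqlen_k_list=[4096, total_q],
)  # causal mask on the last chunk only by default; out: [total_q, nheads, headdim]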
nanovllm/ops/xattn.py (new file, 1167 lines)
File diff suppressed because it is too large
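The xattn.py diff is collapsed here, but its entry points are exercised by the test below. A minimal call using the same arguments as that test (illustrative; shapes are placeholders):

import torch
from nanovllm.ops import xattn_estimate

query = torch.randn(1, 32, 8192, 128, device="cuda", dtype=torch.bfloat16)  # [B, H, S, D]
key = torch.randn(1, 32, 8192, 128, device="cuda", dtype=torch.bfloat16)

attn_sum, mask = xattn_estimate(
    query, key,
    block_size=64, stride=4, threshold=0.9,
    chunk_size=4096, use_triton=True, causal=True,
)
# mask: boolean block-importance mask [B, H, q_blocks, k_blocks]; density = mask.float().mean()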
tests/test_xattn_estimate_chunked.py (new file, 244 lines)
@@ -0,0 +1,244 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked

Verify that chunked estimation with EXTERNAL chunking produces the same mask
as standard estimation. This ensures the chunked version can be used in
chunked prefill scenarios without accuracy loss.

Usage:
    CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
        python tests/test_xattn_estimate_chunked.py
"""

import sys
import traceback
import torch
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked


# ============================================================
# Configuration
# ============================================================

# Configuration for xattn_estimate_chunked consistency test.
# Key requirements for 100% match:
#   1. Use matching chunk_size for both standard and chunked versions
#   2. Use same random seed for reproducibility
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
# floating point precision in cumulative sum calculations.
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096  # External chunking size

# Test sequence lengths
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]


# ============================================================
# Utility Functions
# ============================================================

def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
    """Compare two masks and report differences."""
    if mask1.shape != mask2.shape:
        print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
        return False

    diff = (mask1 != mask2).sum().item()
    total = mask1.numel()
    match_rate = (total - diff) / total * 100

    print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")

    if diff > 0:
        diff_indices = torch.where(mask1 != mask2)
        print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")

    return diff == 0


def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
    """
    Run xattn_estimate_chunked with EXTERNAL chunking.
    This simulates how chunked prefill should be used in practice.
    """
    batch_size, num_heads, q_len, head_dim = query.shape
    _, _, k_len, _ = key.shape

    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size

    # If Q fits in one chunk, call directly
    if q_len <= chunk_size:
        return xattn_estimate_chunked(
            query, key,
            q_start_pos=0,
            block_size=block_size,
            stride=stride,
            threshold=threshold,
            use_triton=True,
            chunk_size=chunk_size,
        )

    # External chunking: split Q and call for each chunk
    num_q_chunks = (q_len + chunk_size - 1) // chunk_size
    print(f" External chunking: {num_q_chunks} chunks")

    combined_attn_sum = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=query.dtype, device=query.device
    )
    combined_mask = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=torch.bool, device=query.device
    )

    q_block_offset = 0
    for q_chunk_idx in range(num_q_chunks):
        q_chunk_start = q_chunk_idx * chunk_size
        q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)

        q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]

        # For causal attention, K accumulates up to the current Q position.
        # q_start_pos=0 means Q starts at position 0 in the full sequence,
        # so K covers [0, q_chunk_end) for causal attention.
        k_end = q_chunk_end
        k_chunk = key[:, :, :k_end, :]

        attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
            q_chunk, k_chunk,
            q_start_pos=q_chunk_start,
            block_size=block_size,
            stride=stride,
            threshold=threshold,
            use_triton=True,
            chunk_size=chunk_size,
        )

        # Place chunk results into combined output
        chunk_q_blocks = mask_chunk.shape[2]
        chunk_k_blocks = mask_chunk.shape[3]
        combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
        combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
        q_block_offset += chunk_q_blocks

    return combined_attn_sum, combined_mask


def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
    """Test a single sequence length."""
    print(f"\nTesting seq_len={seq_len}")
    print("=" * 60)

    # Generate random Q/K
    query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)

    # Run standard xattn_estimate
    print("[1] Running standard xattn_estimate...")
    try:
        attn_sum_std, mask_std = xattn_estimate(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            use_triton=True,
            causal=True,
        )
        density_std = mask_std.float().mean().item()
        print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
    except Exception as e:
        print(f" ERROR: {e}")
        traceback.print_exc()
        return False

    # Run chunked xattn_estimate with EXTERNAL chunking
    print("[2] Running chunked xattn_estimate (external chunking)...")
    try:
        attn_sum_chunked, mask_chunked = run_chunked_externally(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
        )
        density_chunked = mask_chunked.float().mean().item()
        print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
    except Exception as e:
        print(f" ERROR: {e}")
        traceback.print_exc()
        return False

    # Compare results
    print("[3] Comparing results...")
    chunked_q_blocks = mask_chunked.shape[2]
    chunked_k_blocks = mask_chunked.shape[3]

    # Extract comparable region from standard mask
    mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]

    # Compare masks
    masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")

    # Compare attn_sums
    attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
        attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
        print(f" Attn sum max diff: {attn_diff:.6f}")
    else:
        print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")

    # Clean up GPU memory
    del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
    torch.cuda.empty_cache()

    return masks_match


# ============================================================
# Main Test
# ============================================================

if __name__ == "__main__":
    print("XAttention Chunked vs Standard Test")
    print("=" * 60)
    print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
    print(f"External chunk_size={CHUNK_SIZE}")
    print()

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("CUDA not available!")
        sys.exit(1)

    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print("✓ xattn_estimate imported")
    print("✓ xattn_estimate_chunked imported")

    # Run tests
    all_passed = True
    results = []

    for seq_len in TEST_SEQ_LENS:
        passed = test_single_seq_len(seq_len)
        chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
        results.append((seq_len, chunks, passed))
        if not passed:
            all_passed = False

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for seq_len, chunks, passed in results:
        status = "PASSED" if passed else "FAILED"
        print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")

    print("=" * 60)
    if all_passed:
        print("ALL TESTS PASSED!")
        sys.exit(0)
    else:
        print("SOME TESTS FAILED!")
        sys.exit(1)