Files
nano-vllm/nanovllm/kvcache/sparse/xattn.py
2026-01-22 22:20:34 +08:00

311 lines
11 KiB
Python

"""
XAttention sparse attention policy for nano-vllm.
Implements the XAttention algorithm from COMPASS, using chunked estimation
and block sparse attention for efficient long-context inference.
Architecture:
XAttention = Estimate (Triton) + Compute (BSA)
- Estimate: xattn_estimate() computes block-level importance scores
- Compute: block_sparse_attn_func() executes sparse attention
Reference: COMPASS/compass/src/Xattention.py
"""
import math
from typing import Optional
import torch
import torch.nn.functional as F
from nanovllm.kvcache.sparse.policy import AttentionPolicy
# BSA block size is fixed at 128 (hardcoded in block_sparse_attn).
# XAttentionPolicy forces its own block_size to this value whenever BSA is
# enabled, and uses it to derive the number of blocks per sequence when
# trimming the estimated block mask.
BSA_BLOCK_SIZE = 128
class XAttentionPolicy(AttentionPolicy):
    """
    XAttention sparse prefill policy using chunked estimation + block sparse attention.

    This policy estimates sparse attention patterns by:
    1. Chunked QK computation using Triton kernels (via nanovllm.ops.xattn)
    2. Block-wise softmax with importance scores
    3. Block selection based on threshold
    4. Block sparse attention computation using MIT-HAN-LAB BSA library

    The key method is estimate() which calls xattn_estimate() from nanovllm.ops
    to compute the sparse attention mask.

    Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
    BSA library: https://github.com/mit-han-lab/Block-Sparse-Attention
    """

    supports_prefill = True
    supports_decode = True  # Uses default FlashAttention for decode

    def __init__(
        self,
        stride: int = 8,
        threshold: float = 0.9,
        block_size: int = 128,
        chunk_size: int = 16384,
        use_triton: bool = True,
        keep_sink: bool = False,
        keep_recent: bool = False,
        norm: float = 1.0,
        use_bsa: bool = True,
    ):
        """
        Initialize XAttention policy.

        Optional backends (Triton, block_sparse_attn) are probed here; when one
        is missing or unusable the corresponding flag is switched off and a
        message is printed instead of raising.

        Args:
            stride: Stride for reorganizing Q/K (default: 8)
            threshold: Block selection threshold, 0-1 (default: 0.9)
            block_size: Block size for sparse attention (default: 128, must match BSA)
            chunk_size: Chunk size for estimation (default: 16384)
            use_triton: Use Triton kernels (requires SM 80+)
            keep_sink: Always keep first block (sink tokens)
            keep_recent: Always keep recent diagonal blocks
            norm: Normalization factor for attention scores
            use_bsa: Use Block Sparse Attention library (default: True)
        """
        self.stride = stride
        self.threshold = threshold
        self.block_size = block_size
        self.chunk_size = chunk_size
        self.use_triton = use_triton
        self.keep_sink = keep_sink
        self.keep_recent = keep_recent
        self.norm = norm
        self.use_bsa = use_bsa

        # BSA requires block_size = 128
        if self.use_bsa and self.block_size != BSA_BLOCK_SIZE:
            print(f"XAttention: BSA requires block_size=128, adjusting from {self.block_size}")
            self.block_size = BSA_BLOCK_SIZE

        # Check Triton availability
        if self.use_triton:
            try:
                import triton  # noqa: F401  (imported only to probe availability)
            except ImportError:
                self.use_triton = False
                print("XAttention: Triton not available. Falling back to PyTorch.")
            else:
                if not torch.cuda.is_available():
                    # Querying device properties without CUDA raises an error
                    # that is not an ImportError, so guard it explicitly.
                    self.use_triton = False
                    print("XAttention: CUDA not available. Falling back to PyTorch.")
                else:
                    props = torch.cuda.get_device_properties(torch.cuda.current_device())
                    if props.major < 8:
                        self.use_triton = False
                        print(f"XAttention: Triton requires SM 80+, got SM {props.major}{props.minor}. Falling back to PyTorch.")

        # Check BSA availability
        if self.use_bsa:
            try:
                from block_sparse_attn import block_sparse_attn_func  # noqa: F401
            except ImportError:
                self.use_bsa = False
                print("XAttention: block_sparse_attn not available. Falling back to FlashAttention.")

    @staticmethod
    def _expand_kv_heads(kv: torch.Tensor, num_heads: int) -> torch.Tensor:
        """
        Expand a K/V tensor whose head axis is dim 1 so it has num_heads heads.

        Each KV head is repeated consecutively (k0, k0, ..., k1, k1, ...), so
        query head h is paired with KV head h // (num_heads // num_kv_heads).
        Using one helper keeps mask estimation and sparse computation on the
        same head pairing.

        Args:
            kv: Tensor with heads on dim 1, e.g. [seq_len, num_kv_heads, head_dim]
                or [batch, num_kv_heads, seq_len, head_dim]
            num_heads: Number of query heads to expand to

        Returns:
            The input unchanged when it already has num_heads heads, otherwise
            a tensor with each head repeated num_heads // num_kv_heads times.
        """
        num_kv_heads = kv.shape[1]
        if num_kv_heads == num_heads:
            return kv
        return kv.repeat_interleave(num_heads // num_kv_heads, dim=1)

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate sparse attention mask using XAttention algorithm.

        Calls xattn_estimate() from nanovllm.ops.xattn to compute block-level
        importance scores and generate a sparse boolean mask.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index (not used by the estimation itself)

        Returns:
            sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
                or None if estimation fails (fallback to full attention)
        """
        try:
            from nanovllm.ops.xattn import xattn_estimate

            num_heads = q.shape[1]

            # Convert to [batch, heads, seq, dim] format expected by xattn_estimate
            q_bhsd = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, seq_len, head_dim]
            k_bhsd = k.unsqueeze(0).transpose(1, 2)  # [1, num_kv_heads, seq_len, head_dim]

            # Handle GQA: expand K with the same head pairing that
            # _block_sparse_attention uses, so the mask row for query head h is
            # scored against the KV head that h actually attends to.
            k_bhsd = self._expand_kv_heads(k_bhsd, num_heads)

            # Only the mask is needed; the per-block attention sums are discarded.
            _, sparse_mask = xattn_estimate(
                q_bhsd, k_bhsd,
                block_size=self.block_size,
                stride=self.stride,
                norm=self.norm,
                threshold=self.threshold,
                chunk_size=self.chunk_size,
                use_triton=self.use_triton,
                causal=True,
                keep_sink=self.keep_sink,
                keep_recent=self.keep_recent,
            )
            return sparse_mask
        except Exception as e:
            # Deliberate best-effort: if estimation fails for any reason,
            # return None so the caller uses full attention instead.
            print(f"XAttention estimate failed: {e}, falling back to full attention")
            return None

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute XAttention sparse prefill attention.

        Flow:
        1. If BSA is disabled, use full FlashAttention directly (no estimation)
        2. Call estimate() to get sparse mask
        3. If mask is None, use full FlashAttention
        4. Otherwise, use block_sparse_attn_func with mask

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        # Without BSA the mask could not be consumed, so skip estimation.
        if not self.use_bsa:
            return self._full_attention(q, k, v, softmax_scale)

        # Step 1: Estimate sparse mask
        sparse_mask = self.estimate(q, k, layer_id)

        # Step 2: Compute attention
        if sparse_mask is None:
            # Estimation failed, fallback to full FlashAttention
            return self._full_attention(q, k, v, softmax_scale)

        # Use block sparse attention with mask
        return self._block_sparse_attention(q, k, v, sparse_mask, softmax_scale)

    def _block_sparse_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sparse_mask: torch.Tensor,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute block sparse attention using MIT-HAN-LAB BSA library.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            sparse_mask: Block mask [batch, num_heads, q_blocks, k_blocks]
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        from block_sparse_attn import block_sparse_attn_func

        seq_len, num_heads = q.shape[0], q.shape[1]

        # Handle GQA: expand K/V to match Q heads
        k = self._expand_kv_heads(k, num_heads)
        v = self._expand_kv_heads(v, num_heads)

        # Cumulative sequence lengths (batch=1)
        cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

        # Head mask type: 1 for all heads using block sparse
        head_mask_type = torch.ones(num_heads, dtype=torch.int32, device=q.device)

        # Trim sparse_mask to actual block counts. Q and K cover the same
        # seq_len here, so one ceil-division gives both block counts.
        num_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
        block_mask = sparse_mask[:, :, :num_blocks, :num_blocks].contiguous()

        # Call BSA
        attn_output = block_sparse_attn_func(
            q, k, v,
            cu_seqlens_q, cu_seqlens_k,
            head_mask_type,
            None,  # streaming_info (left_mask)
            block_mask,
            seq_len, seq_len,
            p_dropout=0.0,
            deterministic=True,
            softmax_scale=softmax_scale,
            is_causal=True,
        )
        return attn_output

    def _full_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute full causal attention using FlashAttention.

        K/V are passed without head expansion; flash_attn_varlen_func is given
        the original num_kv_heads tensors.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        seq_len = q.shape[0]
        # Single sequence (batch=1): boundaries are [0, seq_len] for both Q and K.
        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            softmax_scale=softmax_scale,
            causal=True,
        )

    def reset(self) -> None:
        """Reset policy state (no state to reset for XAttention)."""
        pass

    def __repr__(self) -> str:
        return (f"XAttentionPolicy("
                f"stride={self.stride}, "
                f"threshold={self.threshold}, "
                f"block_size={self.block_size}, "
                f"use_triton={self.use_triton}, "
                f"use_bsa={self.use_bsa})")