"""
XAttention sparse attention policy for nano-vllm.

Implements the XAttention algorithm from COMPASS, using chunked estimation
and block sparse attention for efficient long-context inference.

Architecture:
    XAttention = Estimate (Triton) + Compute (BSA)
    - Estimate: xattn_estimate() computes block-level importance scores
    - Compute: block_sparse_attn_func() executes sparse attention

Reference: COMPASS/compass/src/Xattention.py
"""

import math
from typing import Optional

import torch
import torch.nn.functional as F

from nanovllm.kvcache.sparse.policy import AttentionPolicy

# BSA block size is fixed at 128 (hardcoded in block_sparse_attn)
BSA_BLOCK_SIZE = 128


class XAttentionPolicy(AttentionPolicy):
    """
    XAttention sparse prefill policy using chunked estimation + block sparse attention.

    This policy estimates sparse attention patterns by:
    1. Chunked QK computation using Triton kernels (via nanovllm.ops.xattn)
    2. Block-wise softmax with importance scores
    3. Block selection based on threshold
    4. Block sparse attention computation using MIT-HAN-LAB BSA library

    The key method is estimate() which calls xattn_estimate() from nanovllm.ops
    to compute the sparse attention mask.

    Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
    BSA library: https://github.com/mit-han-lab/Block-Sparse-Attention
    """

    supports_prefill = True
    supports_decode = True  # Uses default FlashAttention for decode

    def __init__(
        self,
        stride: int = 8,
        threshold: float = 0.9,
        block_size: int = 128,
        chunk_size: int = 16384,
        use_triton: bool = True,
        keep_sink: bool = False,
        keep_recent: bool = False,
        norm: float = 1.0,
        use_bsa: bool = True,
    ):
        """
        Initialize XAttention policy.

        Args:
            stride: Stride for reorganizing Q/K (default: 8)
            threshold: Block selection threshold, 0-1 (default: 0.9)
            block_size: Block size for sparse attention (default: 128, must match BSA)
            chunk_size: Chunk size for estimation (default: 16384)
            use_triton: Use Triton kernels (requires SM 80+)
            keep_sink: Always keep first block (sink tokens)
            keep_recent: Always keep recent diagonal blocks
            norm: Normalization factor for attention scores
            use_bsa: Use Block Sparse Attention library (default: True)
        """
        self.stride = stride
        self.threshold = threshold
        self.block_size = block_size
        self.chunk_size = chunk_size
        self.use_triton = use_triton
        self.keep_sink = keep_sink
        self.keep_recent = keep_recent
        self.norm = norm
        self.use_bsa = use_bsa

        # BSA requires block_size = 128
        if self.use_bsa and self.block_size != BSA_BLOCK_SIZE:
            print(f"XAttention: BSA requires block_size=128, adjusting from {self.block_size}")
            self.block_size = BSA_BLOCK_SIZE

        # Check Triton availability
        if self.use_triton:
            try:
                import triton  # noqa: F401 -- imported only to probe availability

                # get_device_properties() raises RuntimeError (not ImportError) on a
                # CPU-only host, so probe CUDA availability first instead of crashing.
                if not torch.cuda.is_available():
                    self.use_triton = False
                    print("XAttention: CUDA not available. Falling back to PyTorch.")
                else:
                    props = torch.cuda.get_device_properties(torch.cuda.current_device())
                    if props.major < 8:
                        self.use_triton = False
                        print(f"XAttention: Triton requires SM 80+, got SM {props.major}{props.minor}. Falling back to PyTorch.")
            except ImportError:
                self.use_triton = False
                print("XAttention: Triton not available. Falling back to PyTorch.")

        # Check BSA availability
        if self.use_bsa:
            try:
                from block_sparse_attn import block_sparse_attn_func  # noqa: F401
            except ImportError:
                self.use_bsa = False
                print("XAttention: block_sparse_attn not available. Falling back to FlashAttention.")

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate sparse attention mask using XAttention algorithm.

        Calls xattn_estimate() from nanovllm.ops.xattn to compute block-level
        importance scores and generate a sparse boolean mask.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index

        Returns:
            sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
                or None if estimation fails (fallback to full attention)
        """
        try:
            from nanovllm.ops.xattn import xattn_estimate

            seq_len, num_heads, head_dim = q.shape
            num_kv_heads = k.shape[1]

            # Convert to [batch, heads, seq, dim] format expected by xattn_estimate
            q_bhsd = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, seq_len, head_dim]
            k_bhsd = k.unsqueeze(0).transpose(1, 2)  # [1, num_kv_heads, seq_len, head_dim]

            # Handle GQA: expand k to match q heads for estimation.
            if num_kv_heads != num_heads:
                repeat_factor = num_heads // num_kv_heads
                # BUGFIX: use repeat_interleave, not repeat(). repeat() tiles heads
                # as [k0,k1,...,k0,k1,...], pairing query head h with kv head
                # h % num_kv_heads. GQA groups consecutive query heads per kv head,
                # so kv head i must occupy slots [i*r, (i+1)*r) -- which is what
                # repeat_interleave produces (and what _block_sparse_attention uses).
                k_bhsd = k_bhsd.repeat_interleave(repeat_factor, dim=1)

            # Call xattn_estimate; attention sums are not needed, only the mask.
            _attn_sums, sparse_mask = xattn_estimate(
                q_bhsd, k_bhsd,
                block_size=self.block_size,
                stride=self.stride,
                norm=self.norm,
                threshold=self.threshold,
                chunk_size=self.chunk_size,
                use_triton=self.use_triton,
                causal=True,
                keep_sink=self.keep_sink,
                keep_recent=self.keep_recent,
            )

            return sparse_mask

        except Exception as e:
            # Deliberate best-effort: any estimation failure degrades to full
            # attention rather than aborting the forward pass.
            print(f"XAttention estimate failed: {e}, falling back to full attention")
            return None

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute XAttention sparse prefill attention.

        Flow:
        1. Call estimate() to get sparse mask
        2. If mask is None or BSA unavailable, use full FlashAttention
        3. Otherwise, use block_sparse_attn_func with mask

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        # If BSA is disabled, use full attention directly (skip estimation)
        if not self.use_bsa:
            return self._full_attention(q, k, v, softmax_scale)

        # Step 1: Estimate sparse mask
        sparse_mask = self.estimate(q, k, layer_id)

        # Step 2: Compute attention
        if sparse_mask is None:
            # Estimation failed, fallback to full FlashAttention
            return self._full_attention(q, k, v, softmax_scale)

        # Use block sparse attention with mask
        return self._block_sparse_attention(q, k, v, sparse_mask, softmax_scale)

    def _block_sparse_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        sparse_mask: torch.Tensor,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute block sparse attention using MIT-HAN-LAB BSA library.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            sparse_mask: Block mask [batch, num_heads, q_blocks, k_blocks]
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        from block_sparse_attn import block_sparse_attn_func

        seq_len, num_heads, head_dim = q.shape
        num_kv_heads = k.shape[1]

        # Handle GQA: expand K/V to match Q heads (kv head i -> slots [i*r, (i+1)*r))
        if num_kv_heads != num_heads:
            repeat_factor = num_heads // num_kv_heads
            k = k.repeat_interleave(repeat_factor, dim=1)
            v = v.repeat_interleave(repeat_factor, dim=1)

        # Cumulative sequence lengths (batch=1)
        cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

        # Head mask type: 1 for all heads using block sparse
        head_mask_type = torch.ones(num_heads, dtype=torch.int32, device=q.device)

        # Trim sparse_mask to actual block counts (estimation may pad to chunk size)
        q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
        k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
        block_mask = sparse_mask[:, :, :q_blocks, :k_blocks].contiguous()

        # Call BSA
        attn_output = block_sparse_attn_func(
            q, k, v,
            cu_seqlens_q, cu_seqlens_k,
            head_mask_type,
            None,  # streaming_info (left_mask)
            block_mask,
            seq_len, seq_len,
            p_dropout=0.0,
            deterministic=True,
            softmax_scale=softmax_scale,
            is_causal=True,
        )

        return attn_output

    def _full_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute full causal attention using FlashAttention.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        seq_len = q.shape[0]
        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            softmax_scale=softmax_scale,
            causal=True,
        )

    def reset(self) -> None:
        """Reset policy state (no state to reset for XAttention)."""
        pass

    def __repr__(self) -> str:
        return (f"XAttentionPolicy("
                f"stride={self.stride}, "
                f"threshold={self.threshold}, "
                f"block_size={self.block_size}, "
                f"use_triton={self.use_triton}, "
                f"use_bsa={self.use_bsa})")