Files
nano-vllm/nanovllm/kvcache/sparse/xattn_bsa.py
Zijie Tian 11a867f6fb 🐛 fix: skip GQA buffer allocation in XAttention offload mode
In offload mode, GQA expansion buffers (_k_expanded, _v_expanded) are not
needed since compute_chunked_prefill() handles GQA inline. Previously,
these buffers were always allocated based on max_model_len, causing OOM
on 24GB GPUs (e.g., RTX 3090) when max_model_len=1M (16GB buffer).

Changes:
- Add enable_cpu_offload parameter to alloc_policy_metadata() in base class
- Skip GQA buffer allocation when enable_cpu_offload=True in XAttentionBSAPolicy
- Pass enable_cpu_offload from model_runner to policy

Memory savings: ~16GB for 1M seq, ~1.1GB for 72K seq

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 02:57:18 +08:00

1131 lines
50 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
XAttention Block Sparse Attention (BSA) Policy for nano-vllm.
This module implements XAttention-inspired block sparse attention for chunked prefill.
Key design:
1. Use xattn_estimate_chunked to estimate sparse block mask
2. Use BSA kernel for efficient sparse attention computation
3. Support chunked prefill with q_start_pos for correct position handling
Note: Decode phase is not supported - use FullAttentionPolicy for decode.
"""
import logging
import torch
import torch.cuda.nvtx as nvtx
from typing import List, Tuple, TYPE_CHECKING
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.utils.density_observer import DensityObserver
if TYPE_CHECKING:
from nanovllm.kvcache.offload_engine import OffloadEngine
from nanovllm.kvcache.manager import KVCacheManager
from nanovllm.engine.sequence import Sequence
logger = logging.getLogger(__name__)
# Global storage for mask debugging
_DEBUG_SAVE_MASK = False # Set to True to save masks for comparison
_DEBUG_MASK_STORAGE = {}
# Check BSA availability
try:
from block_sparse_attn import block_sparse_attn_func
BSA_AVAILABLE = True
except ImportError:
BSA_AVAILABLE = False
logger.warning("block_sparse_attn not available, XAttentionBSAPolicy will fallback to dense")
# Check xattn_estimate_chunked availability
try:
from nanovllm.ops.xattn import xattn_estimate_chunked
XATTN_AVAILABLE = True
except ImportError:
XATTN_AVAILABLE = False
logger.warning("xattn_estimate_chunked not available")
def expand_kv_for_gqa(
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    num_heads: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Replicate KV heads so their count matches the query head count (GQA).

    Under Grouped Query Attention each KV head serves
    ``num_heads // num_kv_heads`` query heads; this helper materializes
    that mapping by repeating every KV head along the head dimension.

    Args:
        key_states: [B, num_kv_heads, seq_len, head_dim]
        value_states: [B, num_kv_heads, seq_len, head_dim]
        num_heads: Number of query heads.

    Returns:
        (key, value) pair shaped [B, num_heads, seq_len, head_dim].
        The inputs are returned unchanged when no expansion is needed.
    """
    kv_heads = key_states.shape[1]
    if kv_heads == num_heads:
        # Plain MHA: query and KV head counts already agree.
        return key_states, value_states
    group = num_heads // kv_heads
    expanded_key = key_states.repeat_interleave(group, dim=1)
    expanded_value = value_states.repeat_interleave(group, dim=1)
    return expanded_key, expanded_value
class XAttentionBSAPolicy(SparsePolicy):
"""
XAttention Block Sparse Attention policy for chunked prefill.
Uses xattn_estimate_chunked to estimate sparse mask, then BSA kernel
for efficient sparse attention computation.
Note:
- Only supports prefill phase (decode uses FullAttentionPolicy)
- BSA block size is fixed at 128 tokens
"""
supports_prefill = True
supports_decode = False # Decode uses FullAttentionPolicy
requires_block_selection = False # Selection happens internally
# BSA requires 128-token blocks
BSA_BLOCK_SIZE = 128
def __init__(
    self,
    threshold: float = 0.95,  # High threshold for accuracy testing
    stride: int = 8,
    chunk_size: int = 4096,  # Match offload Q chunk size for density alignment
    block_size: int = 128,
    samples_per_chunk: int = 128,
    use_triton: bool = True,
    estimate_block_size: int = 1024,  # Optimized block size for softmax_fuse_block_sum
):
    """
    Configure the XAttention BSA policy.

    Args:
        threshold: Cumulative attention mass each Q block must cover when
            selecting K blocks (0-1). Larger values keep more blocks,
            i.e. produce a denser mask.
        stride: Reshape stride applied to Q/K during estimation
            (typically 8).
        chunk_size: Processing chunk length used by xattn_estimate
            (Triton alignment).
        block_size: BSA block size; the kernel only supports 128.
        samples_per_chunk: Estimation samples per chunk (unused).
        use_triton: Enable Triton kernels for estimation.
        estimate_block_size: Block size handed to softmax_fuse_block_sum
            in select_blocks. 1024 is the tuned default (15x faster than
            4096) and must divide cpu_block_size evenly (e.g. 4096/1024=4).
    """
    # Estimation hyper-parameters.
    self.threshold = threshold
    self.stride = stride
    self.chunk_size = chunk_size
    self.use_triton = use_triton
    self.estimate_block_size = estimate_block_size
    # Filled in lazily during the first forward pass.
    self._num_heads = None
    # Per-layer attention-score metadata:
    # Dict[layer_id, Tensor[num_q_blocks, num_k_blocks]].
    self.sparse_metadata: dict = {}
    # Running counters used for density tracking.
    self._stats_total_available_blocks = 0
    self._stats_total_selected_blocks = 0
    self._stats_num_chunks = 0
    # GQA expansion buffers for GPU-only mode; populated by
    # alloc_policy_metadata(), left as None when not pre-allocated.
    self._k_expanded: torch.Tensor | None = None
    self._v_expanded: torch.Tensor | None = None
    self._max_seq_len: int = 0
    # Mask buffer reused across chunked-prefill steps (offload mode).
    # Holds the BSA-level mask produced by select_blocks for later use in
    # compute_chunked_prefill.
    # Shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks].
    self._prefill_mask_buffer: torch.Tensor | None = None
    self._current_mask_q_bsa: int = 0  # valid Q BSA blocks in the buffer
    self._current_mask_k_bsa: int = 0  # valid K BSA blocks in the buffer
    # Indices of the CPU blocks chosen within available_blocks;
    # compute_chunked_prefill uses them to extract matching mask columns.
    self._selected_cpu_indices: List[int] = []
    self._bsa_per_cpu: int = 0  # BSA blocks contained in one CPU block
    # -----------------------------------------------------------------
    # Buffers for the 3-stage KV chunking pipeline (offload mode)
    # -----------------------------------------------------------------
    # Partial softmax statistics (running max m, exp-sum l) per KV chunk.
    # Shape: [max_kv_chunks, batch, heads, q_reshaped_len].
    self._m_partial_buffer: torch.Tensor | None = None
    self._l_partial_buffer: torch.Tensor | None = None
    # Normalized attention block sums over all K blocks.
    # Shape: [batch, heads, max_q_bsa_blocks, max_k_bsa_blocks].
    self._block_sums_buffer: torch.Tensor | None = None
    # KV chunking configuration (resolved at runtime).
    self._max_kv_chunks: int = 0
    self._cpu_block_size: int = 0  # tokens per CPU block
def alloc_policy_metadata(
    self,
    num_heads: int,
    num_kv_heads: int,
    head_dim: int,
    max_seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    enable_cpu_offload: bool = False,
) -> None:
    """
    Pre-allocate GQA expansion buffers for GPU-only mode.

    These buffers are used by compute_prefill() to avoid dynamic allocation
    during forward pass. The buffers are sized for max_seq_len and sliced
    to actual seq_len during use.

    Memory usage: 2 * num_heads * max_seq_len * head_dim * dtype_size
    For 64K seq, 32 heads, 128 dim, fp16: 2 * 32 * 65536 * 128 * 2 = 1 GB

    Args:
        num_heads: Number of query heads
        num_kv_heads: Number of KV heads (for GQA)
        head_dim: Dimension per head
        max_seq_len: Maximum sequence length
        dtype: Data type
        device: Target device
        enable_cpu_offload: When True, skip the GQA expansion buffers
            entirely: chunked prefill (compute_chunked_prefill) handles
            GQA expansion inline, so the large _k_expanded/_v_expanded
            buffers would only waste GPU memory (~16 GB at 1M seq).
    """
    # Pre-allocate mask buffer for chunked prefill (offload mode)
    # mask shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks]
    # This is needed regardless of GQA
    max_q_bsa_blocks = self.chunk_size // self.BSA_BLOCK_SIZE
    max_k_bsa_blocks = max_seq_len // self.BSA_BLOCK_SIZE
    mask_shape = (1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks)
    self._prefill_mask_buffer = torch.empty(mask_shape, dtype=torch.bool, device=device)
    # bool tensors use 1 byte per element, hence no dtype factor here.
    mask_memory_mb = num_heads * max_q_bsa_blocks * max_k_bsa_blocks / (1024 * 1024)
    logger.info(f"[XAttn] Pre-allocated mask buffer: shape={mask_shape}, memory={mask_memory_mb:.1f} MB")
    # =====================================================================
    # Pre-allocate buffers for 3-stage KV chunking (offload mode)
    # =====================================================================
    # Calculate max KV chunks: historical blocks + current chunk
    # Use cpu_block_size as KV chunk granularity (will be set at runtime)
    # For now, estimate based on chunk_size (actual cpu_block_size may differ)
    estimated_cpu_block_size = 4096  # Default, will be overwritten
    max_kv_chunks = (max_seq_len // estimated_cpu_block_size) + 1  # +1 for current chunk
    # Q reshaped length for one chunk
    q_reshaped_len = self.chunk_size // self.stride
    # NOTE(review): kv_chunk_reshaped_len is computed but not used below;
    # presumably kept for symmetry with select_blocks — confirm before removing.
    kv_chunk_reshaped_len = estimated_cpu_block_size // self.stride
    # Partial stats buffers: [max_kv_chunks, batch=1, heads, q_reshaped_len]
    m_partial_shape = (max_kv_chunks, 1, num_heads, q_reshaped_len)
    self._m_partial_buffer = torch.empty(m_partial_shape, dtype=torch.float32, device=device)
    self._l_partial_buffer = torch.empty(m_partial_shape, dtype=torch.float32, device=device)
    # Block sums buffer: [batch=1, heads, max_q_bsa_blocks, max_k_bsa_blocks]
    block_sums_shape = (1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks)
    self._block_sums_buffer = torch.empty(block_sums_shape, dtype=dtype, device=device)
    self._max_kv_chunks = max_kv_chunks
    # Memory calculation (m/l are float32, hence the factor of 4 bytes)
    m_l_memory_mb = 2 * max_kv_chunks * num_heads * q_reshaped_len * 4 / (1024 * 1024)
    block_sums_memory_mb = num_heads * max_q_bsa_blocks * max_k_bsa_blocks * dtype.itemsize / (1024 * 1024)
    logger.info(f"[XAttn] Pre-allocated KV chunking buffers: "
                f"m/l shape={m_partial_shape} ({m_l_memory_mb:.1f} MB), "
                f"block_sums shape={block_sums_shape} ({block_sums_memory_mb:.1f} MB)")
    # Skip GQA buffers in offload mode
    # Chunked prefill uses compute_chunked_prefill() which handles GQA inline
    if enable_cpu_offload:
        logger.info("[XAttn] Offload mode: skipping GQA expansion buffers (saves ~16GB for 1M seq)")
        return
    # GPU-only mode: pre-allocate GQA buffers for compute_prefill()
    # Only allocate if GQA (num_heads != num_kv_heads)
    if num_heads == num_kv_heads:
        logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
        return
    # Shape: [1, num_heads, max_seq_len, head_dim] for xattn_estimate format
    # Also used for BSA which expects [seq_len, num_heads, head_dim]
    shape = (1, num_heads, max_seq_len, head_dim)
    self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
    self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
    self._max_seq_len = max_seq_len
    memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)
    logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")
# =========================================================================
# GPU-only methods (non-chunked)
# =========================================================================
def compute_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    layer_id: int,
    block_tables: torch.Tensor = None,
) -> torch.Tensor:
    """
    GPU-only prefill attention using XAttention + BSA.

    This method implements sparse attention for GPU-only mode:
    1. Estimate block importance using xattn_estimate
    2. Compute sparse attention using block_sparse_attn_func

    Args:
        q: Query tensor [total_q, num_heads, head_dim] (varlen packed)
        k: Key tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
        v: Value tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
        cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
        cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
        max_seqlen_q: Maximum Q sequence length
        max_seqlen_k: Maximum K sequence length
        softmax_scale: Softmax scaling factor
        layer_id: Transformer layer index
        block_tables: Paged attention block tables (not used for XAttention)

    Returns:
        Attention output [total_q, num_heads, head_dim]
    """
    # Fallback to flash attention when:
    # 1. block_tables provided (paged KV cache / prefix cache) - XAttention expects contiguous K, V
    # 2. BSA kernel not available
    # 3. xattn_estimate not available
    if block_tables is not None or not BSA_AVAILABLE or not XATTN_AVAILABLE:
        from flash_attn import flash_attn_varlen_func
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=softmax_scale,
            causal=True,
            block_table=block_tables,
        )
    from nanovllm.ops.xattn import xattn_estimate
    # Set DensityObserver mode on first layer
    if layer_id == 0:
        DensityObserver.set_mode("gpu_only")
    # Get dimensions
    total_q, num_heads, head_dim = q.shape
    total_kv, num_kv_heads, _ = k.shape
    # For now, assume batch_size = 1 (single sequence)
    # TODO: Support batched varlen format
    batch_size = cu_seqlens_q.shape[0] - 1
    if batch_size != 1:
        # Fallback to flash attention for batched input
        from flash_attn import flash_attn_varlen_func
        logger.warning(f"[XAttn] batch_size={batch_size} > 1, falling back to flash attention")
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=softmax_scale,
            causal=True,
        )
    q_len = max_seqlen_q
    k_len = max_seqlen_k
    # Convert from varlen format [total, heads, dim] to [batch, heads, seq, dim]
    # q: [q_len, num_heads, head_dim] -> [1, num_heads, q_len, head_dim]
    Q = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, q_len, head_dim]
    K = k.unsqueeze(0).transpose(1, 2)  # [1, num_kv_heads, k_len, head_dim]
    V = v.unsqueeze(0).transpose(1, 2)  # [1, num_kv_heads, k_len, head_dim]
    # Expand KV for GQA - use pre-allocated buffers if available
    if num_heads != num_kv_heads:
        num_groups = num_heads // num_kv_heads
        if self._k_expanded is not None and k_len <= self._max_seq_len:
            # Use pre-allocated buffers with in-place expansion
            K_exp = self._k_expanded[:, :, :k_len, :]
            V_exp = self._v_expanded[:, :, :k_len, :]
            # In-place GQA expansion: [1, num_kv_heads, k_len, head_dim] -> [1, num_heads, k_len, head_dim]
            # Reshape K to [1, num_kv_heads, 1, k_len, head_dim] and broadcast to [1, num_kv_heads, num_groups, k_len, head_dim]
            K_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
                K.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
            )
            V_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
                V.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
            )
        else:
            # Fallback: dynamic allocation (when buffers not pre-allocated or seq too long)
            K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads)
    else:
        K_exp, V_exp = K, V
    # Estimate block importance and get sparse mask
    with nvtx.range("xattn_estimate"):
        attn_sums, mask = xattn_estimate(
            Q, K_exp,
            chunk_size=self.chunk_size,
            block_size=self.BSA_BLOCK_SIZE,
            stride=self.stride,
            threshold=self.threshold,
            use_triton=self.use_triton,
            causal=True,
        )
    # Debug: Save Q, K, mask, attn_sums for external verification
    if _DEBUG_SAVE_MASK and layer_id == 0:
        import os
        valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
        valid_k_blocks = (k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
        mask_valid = mask[:, :, :valid_q_blocks, :valid_k_blocks]
        attn_sums_valid = attn_sums[:, :, :valid_q_blocks, :valid_k_blocks]
        save_dir = "/home/zijie/Code/nano-vllm/results/mask_alignment"
        os.makedirs(save_dir, exist_ok=True)
        save_path = f"{save_dir}/gpuonly_layer{layer_id}.pt"
        torch.save({
            # Input tensors (GQA-expanded)
            "Q": Q.clone().cpu(),  # [1, num_heads, q_len, head_dim]
            "K": K_exp.clone().cpu(),  # [1, num_heads, k_len, head_dim]
            # xattn_estimate parameters
            "chunk_size": self.chunk_size,
            "block_size": self.BSA_BLOCK_SIZE,
            "stride": self.stride,
            "threshold": self.threshold,
            # Output for comparison
            "mask": mask_valid.clone().cpu(),
            "attn_sums": attn_sums_valid.clone().cpu(),
            # Metadata
            "q_len": q_len,
            "k_len": k_len,
            "valid_q_blocks": valid_q_blocks,
            "valid_k_blocks": valid_k_blocks,
        }, save_path)
        logger.info(f"[DEBUG] Saved Q/K/mask to {save_path}, Q={Q.shape}, K={K_exp.shape}")
    # Compute block counts
    q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
    k_block_num = (k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
    # Prepare tensors for BSA
    # q, k, v need to be [seq_len, num_heads, head_dim]
    q_bsa = q  # Already [q_len, num_heads, head_dim]
    # For GQA with BSA, reuse the expanded K_exp, V_exp (convert to BSA format)
    # K_exp: [1, num_heads, k_len, head_dim] -> [k_len, num_heads, head_dim]
    if num_heads != num_kv_heads:
        k_bsa = K_exp.squeeze(0).transpose(0, 1)  # [k_len, num_heads, head_dim]
        v_bsa = V_exp.squeeze(0).transpose(0, 1)  # [k_len, num_heads, head_dim]
    else:
        k_bsa = k
        v_bsa = v
    # Prepare BSA inputs
    cu_seqlens_q_bsa = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
    cu_seqlens_k_bsa = torch.tensor([0, k_len], dtype=torch.int32, device=k.device)
    head_groups = torch.ones(num_heads, dtype=torch.int32, device=q.device)
    # Trim mask to actual block counts
    mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()
    # Compute sparse attention using BSA
    with nvtx.range("xattn_bsa_compute"):
        output = block_sparse_attn_func(
            q_bsa, k_bsa, v_bsa,
            cu_seqlens_q_bsa,
            cu_seqlens_k_bsa,
            head_groups,
            None,  # key_padding_mask
            mask_trimmed,
            q_len, k_len,
            p_dropout=0.0,
            deterministic=True,
            is_causal=True,
        )
    # Record density on layer 0 only via DensityObserver.
    # (Previous debug code here also computed causal-masked selection counts
    # with a host-synchronizing .item(); that work was unused and removed.)
    if layer_id == 0:
        DensityObserver.record(layer_id, mask_trimmed, causal=True)
    return output
def compute_decode(
    self,
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    cache_seqlens: torch.Tensor,
    softmax_scale: float,
    layer_id: int,
    block_tables: torch.Tensor = None,
) -> torch.Tensor:
    """
    GPU-only decode attention - delegates to FullAttentionPolicy.

    XAttention targets long prefill sequences; for single-token decode
    the dense path (flash_attn_with_kvcache via FullAttentionPolicy)
    is used instead.
    """
    from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
    dense_policy = FullAttentionPolicy()
    return dense_policy.compute_decode(
        q,
        k_cache,
        v_cache,
        cache_seqlens,
        softmax_scale,
        layer_id,
        block_tables,
    )
# =========================================================================
# Chunked offload methods
# =========================================================================
def select_blocks(
self,
available_blocks: List[int],
offload_engine: "OffloadEngine",
ctx: PolicyContext,
q: torch.Tensor,
k: torch.Tensor,
) -> List[int]:
"""
Select important blocks using 3-stage KV chunking algorithm.
This method implements the same algorithm as tests/test_xattn_estimate_alignment.py:
1. For each KV chunk: compute attention scores and partial softmax stats
2. Merge all partial stats to get global m and l
3. For each KV chunk: normalize with global stats and compute block sums
4. Use find_blocks_chunked to select important blocks
This approach:
- Uses O(S×C) peak memory instead of O(S²)
- Produces identical density to GPU-only xattn_estimate
- Supports ultra-long contexts
Args:
available_blocks: List of CPU block IDs (historical blocks only)
offload_engine: OffloadEngine for loading blocks
ctx: PolicyContext with metadata
q: Query tensor [seq_len, num_heads, head_dim] for current chunk
k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk
Returns:
Selected block IDs based on attention threshold
"""
if q is None:
return available_blocks
# CRITICAL: Wait for all previous prefill offloads to complete before loading from CPU
# This ensures that the K data we load from k_cache_cpu is actually valid.
# Without this sync, we may load stale/uninitialized data because the async offload
# from the previous chunk hasn't finished yet.
if available_blocks and offload_engine is not None:
offload_engine.wait_all_prefill_offloads()
from nanovllm.ops.xattn import (
flat_group_gemm_fuse_reshape,
softmax_compute_partial_stats,
softmax_normalize_and_block_sum,
merge_softmax_stats,
find_blocks_chunked,
)
import math
layer_id = ctx.layer_id
# Set DensityObserver mode on first layer
if layer_id == 0:
DensityObserver.set_mode("offload")
# ================================================================
# Step 0: Setup parameters
# ================================================================
# Convert Q to [batch, heads, seq_len, head_dim]
Q = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, q_len, head_dim]
num_heads = Q.shape[1]
head_dim = Q.shape[3]
q_len = Q.shape[2]
# Alignment requirements
BLOCK_M = 128 # Triton block size
alignment = self.stride * BLOCK_M # 8 * 128 = 1024
if q_len < alignment:
# Q too short, skip estimation and return all blocks
logger.debug(f"[XAttn] select_blocks: q_len={q_len} < alignment={alignment}, skipping estimation")
return available_blocks
# Pad Q to alignment
padded_q_len = ((q_len + alignment - 1) // alignment) * alignment
q_pad_size = padded_q_len - q_len
if q_pad_size > 0:
Q = torch.nn.functional.pad(Q, (0, 0, 0, q_pad_size), value=0)
# Get CPU block size from context
cpu_block_size = ctx.block_size # e.g., 4096 tokens per CPU block
self._cpu_block_size = cpu_block_size
# KV chunk parameters (use CPU block as KV chunk unit)
num_historical_blocks = len(available_blocks)
historical_k_len = num_historical_blocks * cpu_block_size
total_k_len = historical_k_len + q_len # Include current chunk
# Reshaped dimensions
reshaped_block_size = self.BSA_BLOCK_SIZE // self.stride # 128/8 = 16
q_reshaped_len = padded_q_len // self.stride
kv_chunk_reshaped = cpu_block_size // self.stride
# BSA blocks per CPU block
bsa_per_cpu = cpu_block_size // self.BSA_BLOCK_SIZE # 4096/128 = 32
# Global K position parameters
# Q在全局K序列中的位置 (按照 test_xattn_estimate_alignment.py 的逻辑)
# 对于 chunked softmax我们需要计算 Q 在整个序列中的 BSA block 偏移
# k_block_num = total BSA blocks (padded), q_block_num = Q's BSA blocks (padded)
padded_total_k_len = ((total_k_len + alignment - 1) // alignment) * alignment
k_block_num = padded_total_k_len // self.BSA_BLOCK_SIZE
q_block_num = padded_q_len // self.BSA_BLOCK_SIZE
chunk_start = (k_block_num - q_block_num) * reshaped_block_size # Q 在 reshaped 空间的起始
chunk_end = chunk_start + q_reshaped_len
# real_q_len: 用于 softmax 归一化的有效 Q 长度
k_reshaped_seq_len = padded_total_k_len // self.stride
k_reshaped_num_to_pad = (padded_total_k_len - total_k_len) // self.stride
# Softmax scale
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm
segment_size = min(4096, reshaped_block_size)
# ================================================================
# Step 1: First pass - compute partial stats for all KV chunks
# ================================================================
m_chunks = []
l_chunks = []
num_kv_chunks = num_historical_blocks + 1 # +1 for current chunk
# Get compute_stream for all compute kernels (like attention computation)
compute_stream = offload_engine.compute_stream
with nvtx.range("xattn_estimate_pass1"):
slot = 0
# Process historical blocks (from CPU)
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
# Load K from CPU (on slot_transfer_stream)
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
# wait_slot_layer makes compute_stream wait for H2D transfer
offload_engine.wait_slot_layer(slot)
# All compute kernels run on compute_stream (like attention computation)
with torch.cuda.stream(compute_stream):
k_block = offload_engine.get_k_for_slot(slot) # [1, block_size, num_kv_heads, head_dim]
K_chunk = k_block.transpose(1, 2) # [1, num_kv_heads, block_size, head_dim]
# GQA expansion
num_kv_heads = K_chunk.shape[1]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
# KV offset in reshaped space
kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
# Compute raw attention scores
attn_weights_kv = flat_group_gemm_fuse_reshape(
Q, K_chunk, self.stride,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=False, # K 不完整,不能在这里用 causal
)
# Compute partial stats (带 causal mask)
m_partial, l_partial = softmax_compute_partial_stats(
attn_weights_kv,
reshaped_block_size,
segment_size,
scale,
chunk_start=chunk_start,
kv_offset=kv_offset_reshaped,
is_causal=True,
)
m_chunks.append(m_partial)
l_chunks.append(l_partial)
offload_engine.record_slot_compute_done(slot)
del attn_weights_kv
# Process current chunk K (already on GPU) on compute_stream
with torch.cuda.stream(compute_stream):
# k: [seq_len, num_kv_heads, head_dim] -> [1, num_kv_heads, seq_len, head_dim]
K_current = k.unsqueeze(0).transpose(1, 2)
# GQA expansion for current chunk
num_kv_heads = K_current.shape[1]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
K_current = K_current.repeat_interleave(num_groups, dim=1)
# Pad current K to alignment
curr_k_len = K_current.shape[2]
padded_curr_k_len = ((curr_k_len + alignment - 1) // alignment) * alignment
if padded_curr_k_len != curr_k_len:
K_current = torch.nn.functional.pad(K_current, (0, 0, 0, padded_curr_k_len - curr_k_len), value=0)
# KV offset for current chunk
kv_offset_current = num_historical_blocks * kv_chunk_reshaped
# Compute attention scores for current chunk
attn_weights_curr = flat_group_gemm_fuse_reshape(
Q, K_current, self.stride,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=False,
)
# Compute partial stats for current chunk
m_partial_curr, l_partial_curr = softmax_compute_partial_stats(
attn_weights_curr,
reshaped_block_size,
segment_size,
scale,
chunk_start=chunk_start,
kv_offset=kv_offset_current,
is_causal=True,
)
m_chunks.append(m_partial_curr)
l_chunks.append(l_partial_curr)
del attn_weights_curr
# ================================================================
# Step 2: Merge all partial stats (on compute_stream)
# ================================================================
with torch.cuda.stream(compute_stream):
with nvtx.range("xattn_estimate_merge"):
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
del m_chunks, l_chunks
# ================================================================
# Step 3: Second pass - normalize and compute block sums
# ================================================================
attn_sum_per_kv = []
with nvtx.range("xattn_estimate_pass2"):
slot = 0
# Process historical blocks again
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
# Load K from CPU (on slot_transfer_stream)
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
# wait_slot_layer makes compute_stream wait for H2D transfer
offload_engine.wait_slot_layer(slot)
# All compute kernels run on compute_stream
with torch.cuda.stream(compute_stream):
k_block = offload_engine.get_k_for_slot(slot)
K_chunk = k_block.transpose(1, 2)
num_kv_heads = K_chunk.shape[1]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
# Recompute attention scores (trade-off: compute vs memory)
attn_weights_kv = flat_group_gemm_fuse_reshape(
Q, K_chunk, self.stride,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=False,
)
# Normalize with global stats and compute block sums
block_sum_kv = softmax_normalize_and_block_sum(
attn_weights_kv,
m_global,
l_global,
reshaped_block_size,
segment_size,
chunk_start=chunk_start,
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
scale=scale,
kv_offset=kv_offset_reshaped,
is_causal=True,
)
attn_sum_per_kv.append(block_sum_kv)
offload_engine.record_slot_compute_done(slot)
del attn_weights_kv
# Process current chunk on compute_stream
with torch.cuda.stream(compute_stream):
# Recompute attention scores for current chunk
attn_weights_curr = flat_group_gemm_fuse_reshape(
Q, K_current, self.stride,
chunk_start=chunk_start,
chunk_end=chunk_end,
is_causal=False,
)
block_sum_curr = softmax_normalize_and_block_sum(
attn_weights_curr,
m_global,
l_global,
reshaped_block_size,
segment_size,
chunk_start=chunk_start,
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
scale=scale,
kv_offset=kv_offset_current,
is_causal=True,
)
attn_sum_per_kv.append(block_sum_curr)
del attn_weights_curr, K_current
# ================================================================
# Step 4: Concatenate block sums and select blocks (on compute_stream)
# ================================================================
with torch.cuda.stream(compute_stream):
attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
del attn_sum_per_kv, m_global, l_global
# Calculate q_block offset for find_blocks_chunked
# This is the number of BSA blocks before Q in the full sequence
num_blocks_per_chunk = q_reshaped_len // reshaped_block_size
current_index = k_block_num - q_block_num # Q starts at this BSA block index
with nvtx.range("xattn_find_blocks"):
mask = find_blocks_chunked(
attn_sum_concat,
current_index=current_index,
threshold=self.threshold,
num_to_choose=None,
decoding=False,
mode="prefill",
causal=True,
)
# Apply causal mask post-processing (same as xattn.py lines 1300-1306)
mask[:, :, -q_block_num:, -q_block_num:] = torch.where(
torch.tril(torch.ones(q_block_num, q_block_num, dtype=torch.bool, device=mask.device), diagonal=0),
mask[:, :, -q_block_num:, -q_block_num:],
False,
)
# ================================================================
# Step 5: Record density (only on layer 0)
# ================================================================
if layer_id == 0:
# Trim mask to valid region
valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
valid_k_blocks = (total_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
mask_valid = mask[:, :, :valid_q_blocks, :valid_k_blocks]
attn_sums_valid = attn_sum_concat[:, :, :valid_q_blocks, :valid_k_blocks]
# Compute causal mask for density calculation
q_offset_blocks = valid_k_blocks - valid_q_blocks
indices = torch.arange(valid_k_blocks, device=mask.device).unsqueeze(0)
q_indices = torch.arange(valid_q_blocks, device=mask.device).unsqueeze(1)
causal_mask = indices <= (q_indices + q_offset_blocks)
chunk_total = causal_mask.sum().item() * mask_valid.shape[0] * mask_valid.shape[1]
chunk_selected = (mask_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
DensityObserver.record_counts(layer_id, chunk_selected, chunk_total)
logger.info(f"[XAttn Offload] Layer0 chunk: q_len={q_len}, k_len={total_k_len}, "
f"valid_q_blocks={valid_q_blocks}, valid_k_blocks={valid_k_blocks}, "
f"q_offset={q_offset_blocks}, selected={chunk_selected}, total={chunk_total}, "
f"density={chunk_selected/chunk_total:.4f}")
# Debug: Save mask and attention sums for comparison
if _DEBUG_SAVE_MASK:
import os
chunk_idx = ctx.query_chunk_idx if ctx else 0
save_dir = "/home/zijie/Code/nano-vllm/results/mask_alignment"
os.makedirs(save_dir, exist_ok=True)
save_path = f"{save_dir}/offload_layer{layer_id}_chunk{chunk_idx}.pt"
torch.save({
"mask": mask_valid.clone().cpu(),
"attn_sums": attn_sums_valid.clone().cpu(),
"q_len": q_len,
"k_len": total_k_len,
"valid_q_blocks": valid_q_blocks,
"valid_k_blocks": valid_k_blocks,
"current_index": current_index,
"chunk_start": chunk_start,
}, save_path)
logger.info(f"[DEBUG] Saved mask to {save_path}")
del attn_sum_concat
# ================================================================
# Step 6: Extract historical mask and aggregate to CPU blocks
# ================================================================
B, H, Q_bsa, K_bsa_total = mask.shape
historical_k_bsa = num_historical_blocks * bsa_per_cpu
# Save mask to buffer for compute_chunked_prefill (if needed later)
if self._prefill_mask_buffer is not None and historical_k_bsa > 0:
self._prefill_mask_buffer[:, :, :Q_bsa, :historical_k_bsa].copy_(
mask[:, :, :, :historical_k_bsa]
)
self._current_mask_q_bsa = Q_bsa
self._current_mask_k_bsa = historical_k_bsa
# Aggregate to CPU block level (union across heads, Q blocks, BSA blocks per CPU)
if num_historical_blocks == 0:
return []
mask_historical = mask[:, :, :, :historical_k_bsa]
mask_per_cpu = mask_historical.view(B, H, Q_bsa, num_historical_blocks, bsa_per_cpu)
cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1) # [B, num_cpu]
selected_indices = cpu_needed[0].nonzero().squeeze(-1).tolist()
if isinstance(selected_indices, int):
selected_indices = [selected_indices]
selected_block_ids = [available_blocks[i] for i in selected_indices]
# Always include first block (sink) and last block for safety
if available_blocks and available_blocks[0] not in selected_block_ids:
selected_block_ids.insert(0, available_blocks[0])
if available_blocks and available_blocks[-1] not in selected_block_ids:
selected_block_ids.append(available_blocks[-1])
# Update statistics (only for layer 0 to avoid overcounting)
if layer_id == 0 and available_blocks:
self._stats_total_available_blocks += len(available_blocks)
self._stats_total_selected_blocks += len(selected_block_ids)
self._stats_num_chunks += 1
# Record communication density to DensityObserver
# Comm density = selected_cpu_blocks / available_cpu_blocks
# This is different from compute density (BSA block granularity)
DensityObserver.record_comm_density(
layer_id=layer_id,
selected_cpu_blocks=len(selected_block_ids),
total_cpu_blocks=len(available_blocks),
)
# Log per-chunk density
chunk_density = len(selected_block_ids) / len(available_blocks)
logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, available={len(available_blocks)}, "
f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}")
return selected_block_ids
def compute_chunked_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    current_chunk_idx: int,
    seq: "Sequence",
    num_tokens: int,
    selected_blocks: List[int],
) -> torch.Tensor:
    """
    Compute attention for chunked prefill using XAttention sparse block selection.

    This method handles the chunked prefill computation:
    1. Load and compute attention to historical chunks (using selected_blocks)
    2. Compute attention to current chunk
    3. Merge all results

    Note: The BSA-level mask is saved in self._prefill_mask_buffer by select_blocks().
    Currently we use flash_attn_with_lse for computation (supports LSE merge).
    TODO: Optimize to use BSA kernel with the saved mask for per-head sparse attention.

    Args:
        q: Query tensor [seq_len, num_heads, head_dim]
        k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
        v: Value tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
        layer_id: Current layer index
        softmax_scale: Softmax scaling factor
        offload_engine: OffloadEngine for loading blocks
        kvcache_manager: KVCacheManager for block management
        current_chunk_idx: Current chunk index
        seq: Sequence object
        num_tokens: Number of tokens in current chunk
        selected_blocks: List of CPU block IDs selected by select_blocks

    Returns:
        Attention output [seq_len, num_heads, head_dim]
    """
    # Use FlashInfer-based implementations (more optimized)
    from nanovllm.ops.chunked_attention import (
        flash_attn_with_lse_flashinfer as flash_attn_with_lse,
        merge_attention_outputs_flashinfer as merge_attention_outputs,
    )
    # Flash-attn kernels expect a batch dimension; prefill runs one sequence.
    q_batched = q.unsqueeze(0)  # [1, seq_len, num_heads, head_dim]
    # Running (output, log-sum-exp) accumulators for the online-softmax merge
    # across historical blocks; None until the first block is processed.
    o_acc = None
    lse_acc = None
    compute_stream = offload_engine.compute_stream
    # Use the pre-selected blocks directly
    cpu_block_table = selected_blocks
    # Note: BSA mask is available in self._prefill_mask_buffer (saved by select_blocks)
    # Mask shape: [1, num_heads, Q_bsa, K_bsa] where Q_bsa = self._current_mask_q_bsa
    # Selected indices: self._selected_cpu_indices, bsa_per_cpu: self._bsa_per_cpu
    # TODO: Use this mask with BSA kernel for per-head sparse attention optimization
    if cpu_block_table:
        with nvtx.range("xattn_compute_historical"):
            # Ring-buffer GPU slots used to stage CPU blocks for this layer.
            load_slots = list(range(offload_engine.num_ring_slots))
            num_blocks = len(cpu_block_table)
            if len(load_slots) == 1:
                # Only 1 slot - use synchronous mode: load, wait, compute,
                # release, one block at a time (no transfer/compute overlap).
                slot = load_slots[0]
                for block_idx in range(num_blocks):
                    cpu_block_id = cpu_block_table[block_idx]
                    offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
                    offload_engine.wait_slot_layer(slot)
                    with torch.cuda.stream(compute_stream):
                        prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
                        # Historical KV predates the current chunk entirely, so
                        # every query may attend to it: causal=False.
                        prev_o, prev_lse = flash_attn_with_lse(
                            q_batched, prev_k, prev_v,
                            softmax_scale=softmax_scale,
                            causal=False,
                        )
                        if o_acc is None:
                            o_acc, lse_acc = prev_o, prev_lse
                        else:
                            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
                        # Mark the slot reusable for the next iteration's load.
                        offload_engine.record_slot_compute_done(slot)
            else:
                # Multiple slots - use pipeline: prefetch up to num_slots blocks,
                # then overlap compute on one slot with transfers into the others.
                num_slots = len(load_slots)
                num_preload = min(num_slots, num_blocks)
                for i in range(num_preload):
                    cpu_block_id = cpu_block_table[i]
                    offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
                for block_idx in range(num_blocks):
                    current_slot = load_slots[block_idx % num_slots]
                    offload_engine.wait_slot_layer(current_slot)
                    with torch.cuda.stream(compute_stream):
                        prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
                        prev_o, prev_lse = flash_attn_with_lse(
                            q_batched, prev_k, prev_v,
                            softmax_scale=softmax_scale,
                            causal=False,
                        )
                        offload_engine.record_slot_compute_done(current_slot)
                    # NOTE(review): here the LSE merge runs OUTSIDE the
                    # compute_stream context (on the default stream), while the
                    # single-slot path merges inside it. Presumably safe only if
                    # merge inputs are already synchronized — TODO confirm.
                    if o_acc is None:
                        o_acc, lse_acc = prev_o, prev_lse
                    else:
                        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
                    # Issue next transfer
                    next_block_idx = block_idx + num_slots
                    if next_block_idx < num_blocks:
                        next_slot = load_slots[next_block_idx % num_slots]
                        next_cpu_block_id = cpu_block_table[next_block_idx]
                        offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
    # Compute attention to current chunk (causal mask)
    with nvtx.range("xattn_compute_current"):
        with torch.cuda.stream(compute_stream):
            # Current chunk's KV lives in the GPU prefill buffer (not on CPU).
            k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
            current_o, current_lse = flash_attn_with_lse(
                q_batched, k_curr, v_curr,
                softmax_scale=softmax_scale,
                causal=True,
            )
    # Merge historical and current attention
    with nvtx.range("xattn_compute_merge"):
        with torch.cuda.stream(compute_stream):
            if o_acc is None:
                # No historical blocks were selected: current chunk is the answer.
                final_o = current_o
            else:
                final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
    # Sync default stream with compute_stream before returning
    torch.cuda.default_stream().wait_stream(compute_stream)
    # Remove batch dimension: [1, seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim]
    return final_o.squeeze(0)
def compute_chunked_decode(
    self,
    q: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    seq: "Sequence",
    selected_blocks: List[int],
) -> torch.Tensor:
    """Unsupported: this policy only handles prefill.

    Raises:
        NotImplementedError: always; decode must use FullAttentionPolicy.
    """
    message = (
        "XAttentionBSAPolicy does not support decode phase. "
        "Use FullAttentionPolicy for decode."
    )
    raise NotImplementedError(message)
def reset(self) -> None:
    """Clear per-prefill sparse metadata.

    Density statistics are intentionally left untouched: they accumulate
    across the entire prefill and are cleared via reset_stats() instead.
    """
    metadata = self.sparse_metadata
    metadata.clear()
def reset_stats(self) -> None:
    """Zero all accumulated density counters."""
    # One counter triple: available blocks seen, blocks actually selected,
    # and the number of chunks processed.
    for counter in (
        "_stats_total_available_blocks",
        "_stats_total_selected_blocks",
        "_stats_num_chunks",
    ):
        setattr(self, counter, 0)
def get_density_stats(self) -> dict:
    """Return accumulated block-selection density statistics.

    Returns:
        dict with keys ``total_available_blocks``, ``total_selected_blocks``,
        ``num_chunks`` and ``overall_density`` (selected / available).
        All-zero values are returned when no blocks have been observed yet,
        which also avoids a division by zero.
    """
    available = self._stats_total_available_blocks
    if available == 0:
        # Nothing recorded yet — report a well-defined zero state.
        return {
            "total_available_blocks": 0,
            "total_selected_blocks": 0,
            "num_chunks": 0,
            "overall_density": 0.0,
        }
    selected = self._stats_total_selected_blocks
    return {
        "total_available_blocks": available,
        "total_selected_blocks": selected,
        "num_chunks": self._stats_num_chunks,
        "overall_density": selected / available,
    }
def print_density_stats(self) -> None:
    """Log a one-line summary of the accumulated density statistics."""
    stats = self.get_density_stats()
    chunks = stats["num_chunks"]
    available = stats["total_available_blocks"]
    selected = stats["total_selected_blocks"]
    density = stats["overall_density"]
    logger.info(
        f"[XAttn BSA] Density Stats: chunks={chunks}, "
        f"available={available}, "
        f"selected={selected}, "
        f"density={density:.1%}"
    )
def __repr__(self) -> str:
    """Debug representation exposing the policy's tunable parameters."""
    return (
        "XAttentionBSAPolicy("
        f"threshold={self.threshold}, stride={self.stride})"
    )