Files
nano-vllm/nanovllm/kvcache/sparse/xattn_bsa.py
Zijie Tian 8ab53e7331 🚧 WIP: add DEBUG code for XAttention KV chunking density verification
Add instrumentation to compare GPU-only vs Offload mode density:
- Layer 0 DEBUG output for both modes
- Accumulate selected/total counts across chunks
- Proper causal mask with Q offset handling
- Skip normal offload logic for isolated testing

Test results (threshold=1.0 achieves alignment):
- 32K: GPU-only 0.9999, Offload 0.9999 (diff ~0%)
- 64K: GPU-only 0.9995, Offload 0.9995 (diff ~0%)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
2026-02-01 17:33:23 +08:00

1023 lines
45 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
XAttention Block Sparse Attention (BSA) Policy for nano-vllm.
This module implements XAttention-inspired block sparse attention for chunked prefill.
Key design:
1. Use xattn_estimate_chunked to estimate sparse block mask
2. Use BSA kernel for efficient sparse attention computation
3. Support chunked prefill with q_start_pos for correct position handling
Note: Decode phase is not supported - use FullAttentionPolicy for decode.
"""
import logging
import torch
import torch.cuda.nvtx as nvtx
from typing import List, Tuple, TYPE_CHECKING
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.utils.density_observer import DensityObserver
if TYPE_CHECKING:
from nanovllm.kvcache.offload_engine import OffloadEngine
from nanovllm.kvcache.manager import KVCacheManager
from nanovllm.engine.sequence import Sequence
logger = logging.getLogger(__name__)
# Probe for the optional block-sparse attention kernel package.
# BSA_AVAILABLE is consulted by XAttentionBSAPolicy.compute_prefill(), which
# falls back to dense flash attention when the import failed. Note that
# `block_sparse_attn_func` is left undefined in that case, so every use of it
# must be guarded by BSA_AVAILABLE.
try:
    from block_sparse_attn import block_sparse_attn_func
    BSA_AVAILABLE = True
except ImportError:
    BSA_AVAILABLE = False
    logger.warning("block_sparse_attn not available, XAttentionBSAPolicy will fallback to dense")
# Probe for the XAttention estimation ops shipped in nanovllm.ops.xattn.
# The imported symbol only serves as an availability check here; the functions
# actually used are imported lazily inside the methods that need them.
# XATTN_AVAILABLE is consulted by XAttentionBSAPolicy.compute_prefill().
try:
    from nanovllm.ops.xattn import xattn_estimate_chunked
    XATTN_AVAILABLE = True
except ImportError:
    XATTN_AVAILABLE = False
    logger.warning("xattn_estimate_chunked not available")
def expand_kv_for_gqa(
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    num_heads: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Expand KV for Grouped Query Attention.

    Every KV head is repeated ``num_heads // num_kv_heads`` times along the
    head dimension (dim=1), keeping the copies of one KV head adjacent, so
    that query head ``h`` lines up with KV head ``h // num_groups``.

    Args:
        key_states: [B, num_kv_heads, seq_len, head_dim]
        value_states: [B, num_kv_heads, seq_len, head_dim]
        num_heads: Number of query heads

    Returns:
        Expanded (key, value) with shape [B, num_heads, seq_len, head_dim].
        When ``num_heads == num_kv_heads`` the inputs are returned as-is
        (no copy).

    Raises:
        ValueError: If ``num_heads`` is not a positive multiple of the number
            of KV heads. Without this check the result would silently have
            fewer than ``num_heads`` heads.
    """
    num_kv_heads = key_states.shape[1]
    if num_heads == num_kv_heads:
        return key_states, value_states
    num_groups, remainder = divmod(num_heads, num_kv_heads)
    if remainder != 0 or num_groups <= 0:
        raise ValueError(
            f"num_heads ({num_heads}) must be a positive multiple of "
            f"num_kv_heads ({num_kv_heads}) for GQA expansion"
        )
    return (
        key_states.repeat_interleave(num_groups, dim=1),
        value_states.repeat_interleave(num_groups, dim=1),
    )
class XAttentionBSAPolicy(SparsePolicy):
"""
XAttention Block Sparse Attention policy for chunked prefill.
Uses xattn_estimate_chunked to estimate sparse mask, then BSA kernel
for efficient sparse attention computation.
Note:
- Only supports prefill phase (decode uses FullAttentionPolicy)
- BSA block size is fixed at 128 tokens
"""
supports_prefill = True
supports_decode = False # Decode uses FullAttentionPolicy
requires_block_selection = False # Selection happens internally
# BSA requires 128-token blocks
BSA_BLOCK_SIZE = 128
def __init__(
    self,
    threshold: float = 0.95,  # High threshold for accuracy testing
    stride: int = 8,
    chunk_size: int = 16384,
    block_size: int = 128,
    samples_per_chunk: int = 128,
    use_triton: bool = True,
    estimate_block_size: int = 1024,  # Optimized block size for softmax_fuse_block_sum
):
    """
    Set up the XAttention BSA policy state.

    Args:
        threshold: Cumulative attention threshold for block selection (0-1).
            Higher values = more blocks selected = less sparse.
        stride: Stride for Q/K reshape in estimation (typically 8).
        chunk_size: Processing chunk size for xattn_estimate (Triton alignment).
        block_size: BSA block size (documented as "must be 128"). Accepted for
            interface compatibility; this constructor does not store it and
            the class-level BSA_BLOCK_SIZE is used instead.
        samples_per_chunk: Samples per chunk for estimation. Accepted but
            not stored (unused).
        use_triton: Whether to use Triton kernels.
        estimate_block_size: Block size for softmax_fuse_block_sum in
            select_blocks. Per the original author's note, 1024 is the
            tuned default and it must be a factor of the CPU block size
            (e.g., 4096/1024=4).
    """
    # --- User-facing configuration -------------------------------------
    self.threshold = threshold
    self.stride = stride
    self.chunk_size = chunk_size
    self.use_triton = use_triton
    self.estimate_block_size = estimate_block_size

    # Number of query heads; unknown until the first forward pass.
    self._num_heads = None

    # Per-layer sparse metadata (attention scores):
    # Dict[layer_id, Tensor[num_q_blocks, num_k_blocks]]
    self.sparse_metadata: dict = {}

    # --- Density statistics --------------------------------------------
    self._stats_total_available_blocks = 0
    self._stats_total_selected_blocks = 0
    self._stats_num_chunks = 0

    # --- GPU-only mode: GQA expansion buffers --------------------------
    # Stay None until alloc_policy_metadata() pre-allocates them.
    self._k_expanded: torch.Tensor | None = None
    self._v_expanded: torch.Tensor | None = None
    self._max_seq_len: int = 0

    # --- Offload mode: mask hand-off between select_blocks and
    #     compute_chunked_prefill ---------------------------------------
    # BSA-level mask buffer, shape
    # [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks] once allocated.
    self._prefill_mask_buffer: torch.Tensor | None = None
    # How many Q / K BSA blocks of the buffer are currently valid.
    self._current_mask_q_bsa: int = 0
    self._current_mask_k_bsa: int = 0
    # Indices (into available_blocks) of the selected CPU blocks, used for
    # mask extraction in compute_chunked_prefill.
    self._selected_cpu_indices: List[int] = []
    # Number of BSA blocks per CPU block.
    self._bsa_per_cpu: int = 0

    # --- Debug instrumentation -----------------------------------------
    # Holds the full K cache seen so far, plus accumulated density counts.
    self._debug_k_full: torch.Tensor | None = None
    self._debug_selected: int = 0  # accumulated number of selected blocks
    self._debug_total: int = 0  # accumulated number of candidate blocks
def alloc_policy_metadata(
self,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_seq_len: int,
dtype: torch.dtype,
device: torch.device,
) -> None:
"""
Pre-allocate GQA expansion buffers for GPU-only mode.
These buffers are used by compute_prefill() to avoid dynamic allocation
during forward pass. The buffers are sized for max_seq_len and sliced
to actual seq_len during use.
Memory usage: 2 * num_heads * max_seq_len * head_dim * dtype_size
For 64K seq, 32 heads, 128 dim, fp16: 2 * 32 * 65536 * 128 * 2 = 1 GB
Args:
num_heads: Number of query heads
num_kv_heads: Number of KV heads (for GQA)
head_dim: Dimension per head
max_seq_len: Maximum sequence length
dtype: Data type
device: Target device
"""
# Pre-allocate mask buffer for chunked prefill (offload mode)
# mask shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks]
# This is needed regardless of GQA
max_q_bsa_blocks = self.chunk_size // self.BSA_BLOCK_SIZE
max_k_bsa_blocks = max_seq_len // self.BSA_BLOCK_SIZE
mask_shape = (1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks)
self._prefill_mask_buffer = torch.empty(mask_shape, dtype=torch.bool, device=device)
mask_memory_mb = num_heads * max_q_bsa_blocks * max_k_bsa_blocks / (1024 * 1024)
logger.info(f"[XAttn] Pre-allocated mask buffer: shape={mask_shape}, memory={mask_memory_mb:.1f} MB")
# Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads)
if num_heads == num_kv_heads:
logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
return
# Shape: [1, num_heads, max_seq_len, head_dim] for xattn_estimate format
# Also used for BSA which expects [seq_len, num_heads, head_dim]
shape = (1, num_heads, max_seq_len, head_dim)
self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
self._max_seq_len = max_seq_len
memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")
#DEBUG : buffer for save all K cache
self._debug_k_full = torch.empty((1, num_heads, max_seq_len, head_dim), dtype=dtype, device=device)
self._debug_selected = 0
self._debug_total = 0
# =========================================================================
# GPU-only methods (non-chunked)
# =========================================================================
def compute_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    layer_id: int,
    block_tables: torch.Tensor = None,
) -> torch.Tensor:
    """
    GPU-only prefill attention using XAttention + BSA.

    This method implements sparse attention for GPU-only mode:
    1. Estimate block importance using xattn_estimate
    2. Compute sparse attention using block_sparse_attn_func

    Dense flash attention (flash_attn_varlen_func, causal) is used instead
    when any of the following holds:
    - block_tables is provided (paged KV cache / prefix cache), because
      XAttention expects contiguous K, V;
    - the block_sparse_attn or xattn ops could not be imported;
    - the batch contains more than one sequence (not supported yet).

    Args:
        q: Query tensor [total_q, num_heads, head_dim] (varlen packed)
        k: Key tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
        v: Value tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
        cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
        cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
        max_seqlen_q: Maximum Q sequence length
        max_seqlen_k: Maximum K sequence length
        softmax_scale: Softmax scaling factor. Forwarded to the dense
            fallback only; it is not passed to block_sparse_attn_func.
        layer_id: Transformer layer index
        block_tables: Paged attention block tables (forces the dense fallback)

    Returns:
        Attention output [total_q, num_heads, head_dim]
    """

    def _dense_attention(**extra_kwargs):
        # Single place for every dense fallback; flash_attn is imported
        # lazily so the module can be imported without it.
        from flash_attn import flash_attn_varlen_func
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            softmax_scale=softmax_scale,
            causal=True,
            **extra_kwargs,
        )

    # Paged KV cache / prefix cache: K, V are not contiguous.
    if block_tables is not None:
        return _dense_attention(block_table=block_tables)
    # Either optional dependency missing: nothing sparse can be computed.
    if not (BSA_AVAILABLE and XATTN_AVAILABLE):
        return _dense_attention()

    from nanovllm.ops.xattn import xattn_estimate

    # Set DensityObserver mode on first layer
    if layer_id == 0:
        DensityObserver.set_mode("gpu_only")

    # Get dimensions (the unpacking also enforces 3-D inputs)
    _, num_heads, head_dim = q.shape
    _, num_kv_heads, _ = k.shape

    # For now, assume batch_size = 1 (single sequence)
    # TODO: Support batched varlen format
    batch_size = cu_seqlens_q.shape[0] - 1
    if batch_size != 1:
        logger.warning(f"[XAttn] batch_size={batch_size} > 1, falling back to flash attention")
        return _dense_attention()

    q_len = max_seqlen_q
    k_len = max_seqlen_k

    # Convert from varlen format [total, heads, dim] to [batch, heads, seq, dim]
    Q = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, q_len, head_dim]
    K = k.unsqueeze(0).transpose(1, 2)  # [1, num_kv_heads, k_len, head_dim]
    V = v.unsqueeze(0).transpose(1, 2)  # [1, num_kv_heads, k_len, head_dim]

    # Expand KV for GQA - use pre-allocated buffers if available
    is_gqa = num_heads != num_kv_heads
    if is_gqa:
        num_groups = num_heads // num_kv_heads
        if self._k_expanded is not None and k_len <= self._max_seq_len:
            # Use pre-allocated buffers with in-place expansion
            K_exp = self._k_expanded[:, :, :k_len, :]
            V_exp = self._v_expanded[:, :, :k_len, :]
            # View the destination as [1, num_kv_heads, num_groups, k_len, head_dim]
            # and broadcast each KV head across its group while copying.
            K_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
                K.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
            )
            V_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
                V.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
            )
        else:
            # Fallback: dynamic allocation (when buffers not pre-allocated or seq too long)
            K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads)
    else:
        K_exp, V_exp = K, V

    # Estimate block importance and get sparse mask
    with nvtx.range("xattn_estimate"):
        _, mask = xattn_estimate(
            Q, K_exp,
            chunk_size=self.chunk_size,
            block_size=self.BSA_BLOCK_SIZE,
            stride=self.stride,
            threshold=self.threshold,
            use_triton=self.use_triton,
            causal=True,
        )

    # Compute block counts (ceil division: a partial trailing block counts)
    q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
    k_block_num = (k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE

    # Prepare tensors for BSA: q, k, v need to be [seq_len, num_heads, head_dim]
    q_bsa = q  # Already [q_len, num_heads, head_dim]
    if is_gqa:
        # Reuse the expanded K/V: [1, num_heads, k_len, head_dim] -> [k_len, num_heads, head_dim]
        k_bsa = K_exp.squeeze(0).transpose(0, 1)
        v_bsa = V_exp.squeeze(0).transpose(0, 1)
    else:
        k_bsa = k
        v_bsa = v

    # Prepare BSA inputs
    cu_seqlens_q_bsa = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
    cu_seqlens_k_bsa = torch.tensor([0, k_len], dtype=torch.int32, device=k.device)
    head_groups = torch.ones(num_heads, dtype=torch.int32, device=q.device)

    # Trim mask to actual block counts
    mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()

    # Compute sparse attention using BSA
    with nvtx.range("xattn_bsa_compute"):
        output = block_sparse_attn_func(
            q_bsa, k_bsa, v_bsa,
            cu_seqlens_q_bsa,
            cu_seqlens_k_bsa,
            head_groups,
            None,  # key_padding_mask
            mask_trimmed,
            q_len, k_len,
            p_dropout=0.0,
            deterministic=True,
            is_causal=True,
        )

    # DEBUG: print mask details for layer 0 in GPU-only mode. The .item()
    # calls synchronise with the GPU, so skip them when INFO is disabled.
    if layer_id == 0 and logger.isEnabledFor(logging.INFO):
        q_bk = mask_trimmed.shape[2]
        k_bk = mask_trimmed.shape[3]
        # Same causal convention as the offload debug path: Q block i sits
        # at absolute block position i + (k_bk - q_bk). For the usual
        # q_bk == k_bk prefill this is a plain lower-triangular mask.
        k_idx = torch.arange(k_bk, device=mask_trimmed.device).unsqueeze(0)  # [1, k_bk]
        q_idx = torch.arange(q_bk, device=mask_trimmed.device).unsqueeze(1)  # [q_bk, 1]
        causal_mask = k_idx <= (q_idx + (k_bk - q_bk))  # [q_bk, k_bk]
        causal_total = causal_mask.sum().item() * mask_trimmed.shape[0] * mask_trimmed.shape[1]
        selected = (mask_trimmed & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
        density = selected / causal_total if causal_total > 0 else 0.0
        logger.info(f"[DEBUG GPU-only Layer0] mask_shape={mask_trimmed.shape}, "
                    f"density={density:.6f}, selected={selected}, total={causal_total}")

    # Record density for all layers via DensityObserver
    DensityObserver.record(layer_id, mask_trimmed, causal=True)
    return output
def compute_decode(
    self,
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    cache_seqlens: torch.Tensor,
    softmax_scale: float,
    layer_id: int,
    block_tables: torch.Tensor = None,
) -> torch.Tensor:
    """
    GPU-only decode attention, handed off to FullAttentionPolicy.

    XAttention targets long prefill sequences, so single-token decode is
    not sparsified here: a fresh FullAttentionPolicy is created on every
    call and all arguments are forwarded to its compute_decode() unchanged.
    (The original author notes that FullAttentionPolicy uses
    flash_attn_with_kvcache; that is not visible from this file.)

    Returns:
        Whatever FullAttentionPolicy.compute_decode() returns for the
        given arguments.
    """
    # Imported lazily, at call time rather than module import time.
    from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy

    dense_policy = FullAttentionPolicy()
    return dense_policy.compute_decode(
        q,
        k_cache,
        v_cache,
        cache_seqlens,
        softmax_scale,
        layer_id,
        block_tables,
    )
# =========================================================================
# Chunked offload methods
# =========================================================================
def select_blocks(
    self,
    available_blocks: List[int],
    offload_engine: "OffloadEngine",
    ctx: PolicyContext,
    q: torch.Tensor,
    k: torch.Tensor,
) -> List[int]:
    """
    Block selection hook for chunked (offload) prefill.

    CURRENT WIP/DEBUG BEHAVIOUR: this method always returns
    ``available_blocks`` unchanged. Before returning it
    1. loads the K of every historical CPU block and copies it, expanded to
       num_heads, into the debug buffer ``self._debug_k_full``;
    2. appends the current chunk's K behind the historical K;
    3. on layer 0 only, runs xattn_estimate on (current Q, accumulated K)
       and accumulates selected/total block counts so the density can be
       compared against GPU-only mode.

    The intended "normal" selection logic (flat_group_gemm ->
    softmax_fuse_block_sum -> find_blocks_chunked -> aggregation to CPU
    block level) is kept at the end of the function but is unreachable
    while the debug instrumentation is active.

    Args:
        available_blocks: List of CPU block IDs (historical blocks only),
            assumed to be ordered by sequence position.
        offload_engine: OffloadEngine for loading blocks
        ctx: PolicyContext with metadata (layer_id, block_size, ...)
        q: Query tensor [seq_len, num_heads, head_dim] for current chunk
        k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk
            (used for estimation)

    Returns:
        ``available_blocks`` (no block is filtered out in the current
        debug state).
    """
    if q is None:
        return available_blocks

    layer_id = ctx.layer_id

    # Set DensityObserver mode on first layer
    if layer_id == 0:
        DensityObserver.set_mode("offload")

    # Convert Q to [batch, heads, seq_len, head_dim]
    Q = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, seq_len, head_dim]
    num_heads = Q.shape[1]
    head_dim = Q.shape[3]
    q_len = Q.shape[2]

    # flat_group_gemm requires q_len to be divisible by stride * BLOCK_M
    # (typically 8 * 128 = 1024); Q is padded below if necessary.
    BLOCK_M = 128  # Triton block size
    alignment = self.stride * BLOCK_M
    if q_len < alignment:
        # Q too short, skip estimation and return all blocks
        logger.debug(f"[XAttn] select_blocks: q_len={q_len} < alignment={alignment}, skipping estimation")
        return available_blocks

    # Pad Q to alignment
    padded_q_len = ((q_len + alignment - 1) // alignment) * alignment
    if padded_q_len != q_len:
        pad_size = padded_q_len - q_len
        Q = torch.nn.functional.pad(Q, (0, 0, 0, pad_size), value=0)
    q_reshaped_len = padded_q_len // self.stride

    # Get block size from context
    block_size = ctx.block_size  # tokens per CPU block (e.g., 4096)

    # ============================================================
    # Step 1: Compute chunk_start and related parameters
    # ============================================================
    # Q starts right after the historical blocks, i.e. at token position
    # num_historical_blocks * block_size.
    num_historical_blocks = len(available_blocks)
    historical_k_len = num_historical_blocks * block_size
    chunk_start = historical_k_len // self.stride  # Q's position in reshaped space
    chunk_end = chunk_start + q_reshaped_len
    # For valid Q length tracking (excluding padding)
    valid_q_reshaped = (q_len + self.stride - 1) // self.stride
    real_q_len = chunk_start + valid_q_reshaped

    # The debug instrumentation needs the pre-allocated K buffer. Without
    # it (alloc_policy_metadata() not called, or sequence longer than the
    # buffer) there is nothing to record, so skip instead of crashing;
    # the result would be `available_blocks` either way.
    total_k_len = historical_k_len + q_len
    debug_k_full = self._debug_k_full
    if debug_k_full is None or total_k_len > debug_k_full.shape[2]:
        logger.debug(f"[XAttn] select_blocks: debug K buffer unavailable or too small "
                     f"for k_len={total_k_len}, skipping density instrumentation")
        return available_blocks

    # ============================================================
    # Step 2: Pipeline load historical K blocks
    # ============================================================
    # Key design: load each block, consume it immediately, then release
    # the slot, so only one K block lives in the ring buffer at a time.
    slot = 0
    attn_scores_list = []
    BLOCK_N = 128
    k_alignment = self.stride * BLOCK_N
    with nvtx.range("xattn_estimate_historical"):
        for block_pos, cpu_block_id in enumerate(available_blocks):
            # Load only K from CPU to GPU (V not needed for estimate)
            offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
            offload_engine.wait_slot_layer(slot)
            # Get K only: [1, block_size, num_kv_heads, head_dim]
            k_block = offload_engine.get_k_for_slot(slot)
            # Convert K to [batch, heads, k_len, head_dim]
            K_chunk = k_block.transpose(1, 2)  # [1, num_kv_heads, block_size, head_dim]
            # Handle GQA: expand K heads to match Q heads
            num_kv_heads = K_chunk.shape[1]
            if num_heads != num_kv_heads:
                num_groups = num_heads // num_kv_heads
                K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
            # DEBUG: save all K cache. The destination offset is the
            # block's position within available_blocks (not its physical
            # CPU block id) so that the historical K occupies exactly
            # [0, historical_k_len) as assumed by total_k_len.
            start_pos = block_pos * block_size
            debug_k_full[:, :, start_pos:start_pos + block_size, :].copy_(K_chunk)
            # --- Disabled while debugging (part of the normal path) ---
            # # Pad K if necessary
            # k_len = K_chunk.shape[2]
            # if k_len < k_alignment:
            #     pad_size = k_alignment - k_len
            #     K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
            # # Compute attention scores for this historical block
            # # Historical blocks: all positions < Q, so Q always sees them (full attention)
            # # Use LOCAL chunk_start=0 to match test_xattn_k_chunked.py behavior
            # attn_chunk = flat_group_gemm_fuse_reshape(
            #     Q, K_chunk, self.stride,
            #     chunk_start=0,  # Local: same as test
            #     chunk_end=q_reshaped_len,
            #     is_causal=False,  # Historical K: all visible to Q
            # )
            # attn_scores_list.append(attn_chunk)
            # Mark slot as done for reuse
            offload_engine.record_slot_compute_done(slot)

    # Append the current chunk's K behind the historical K. This must
    # happen for both GQA and non-GQA models; only the head expansion is
    # conditional.
    K_new = k.unsqueeze(0).transpose(1, 2)  # [1, num_kv_heads, q_len, head_dim]
    num_kv_heads = K_new.shape[1]
    if num_heads != num_kv_heads:
        num_groups = num_heads // num_kv_heads
        K_new = K_new.repeat_interleave(num_groups, dim=1)  # [1, num_heads, q_len, head_dim]
    debug_k_full[:, :, historical_k_len:historical_k_len + q_len, :].copy_(K_new)

    # ============================================================
    # DEBUG: accumulate selected/total counts (layer 0 only).
    # Calls xattn_estimate on the full K, same logic as GPU-only mode.
    # ============================================================
    if layer_id == 0:
        from nanovllm.ops.xattn import xattn_estimate
        K_full = debug_k_full[:, :, :total_k_len, :]
        # Estimate with the current Q chunk against the accumulated K.
        # chunk_size is q_len rounded up to stride * BLOCK_M, which is
        # exactly padded_q_len computed above.
        aligned_chunk_size = padded_q_len
        _, mask_chunk = xattn_estimate(
            Q[:, :, :q_len, :],  # current Q chunk (padding stripped)
            K_full,  # accumulated K
            block_size=self.BSA_BLOCK_SIZE,
            stride=self.stride,
            threshold=self.threshold,
            chunk_size=aligned_chunk_size,
            use_triton=self.use_triton,  # keep parity with compute_prefill()
            causal=True,
        )
        # Number of valid blocks (excluding padding)
        valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
        valid_k_blocks = (total_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
        # Trim the mask to the valid region
        mask_valid = mask_chunk[:, :, :valid_q_blocks, :valid_k_blocks]
        # Count selected/total for this chunk, honouring causality and the
        # Q offset: Q starts at block (k_blocks - q_blocks), so Q block i
        # sits at absolute position i + offset and may see K blocks
        # 0 .. i + offset.
        q_blocks = valid_q_blocks
        k_blocks = valid_k_blocks
        q_offset_blocks = k_blocks - q_blocks
        indices = torch.arange(k_blocks, device=mask_valid.device).unsqueeze(0)  # [1, k_blocks]
        q_indices = torch.arange(q_blocks, device=mask_valid.device).unsqueeze(1)  # [q_blocks, 1]
        causal_mask = indices <= (q_indices + q_offset_blocks)  # [q_blocks, k_blocks]
        chunk_total = causal_mask.sum().item() * mask_valid.shape[0] * mask_valid.shape[1]
        chunk_selected = (mask_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
        # Accumulate across chunks
        self._debug_selected += chunk_selected
        self._debug_total += chunk_total
        # Print the density accumulated so far
        if self._debug_total > 0:
            density = self._debug_selected / self._debug_total
            logger.info(f"[DEBUG Offload Layer0] 累积 density: {density:.4f} "
                        f"(selected={self._debug_selected}, total={self._debug_total}, k_len={total_k_len}, "
                        f"mask_shape={mask_chunk.shape}, q_offset={q_offset_blocks})")

    # DEBUG: skip the normal offload logic for every layer and return all
    # blocks.
    return available_blocks

    # ====================================================================
    # NOTE: everything below is UNREACHABLE while the debug return above is
    # in place. It is the normal selection logic, preserved verbatim so it
    # can be re-enabled; the historical scoring inside the Step 2 loop must
    # be re-enabled together with it.
    # ====================================================================
    from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum, find_blocks_chunked
    import math

    # ============================================================
    # Step 3: Get current chunk K and compute its attn_scores
    # ============================================================
    with nvtx.range("xattn_estimate_current"):
        # Current chunk K is in prefill buffer (already on GPU)
        k_curr, _ = offload_engine.get_prefill_buffer_slice(layer_id, q_len)
        # k_curr: [1, q_len, num_kv_heads, head_dim] -> [1, num_kv_heads, q_len, head_dim]
        K_current = k_curr.transpose(1, 2)
        # Handle GQA for current chunk K
        num_kv_heads = K_current.shape[1]
        if num_heads != num_kv_heads:
            num_groups = num_heads // num_kv_heads
            K_current = K_current.repeat_interleave(num_groups, dim=1)
        # Pad current K if necessary
        curr_k_len = K_current.shape[2]
        padded_curr_k_len = ((curr_k_len + k_alignment - 1) // k_alignment) * k_alignment
        if padded_curr_k_len != curr_k_len:
            pad_size = padded_curr_k_len - curr_k_len
            K_current = torch.nn.functional.pad(K_current, (0, 0, 0, pad_size), value=0)
        # Compute attention scores for current chunk
        # IMPORTANT: Use LOCAL coordinates (0 to q_reshaped_len) for current chunk!
        # Because K_current only contains current chunk K (not full sequence),
        # block_n in kernel starts from 0. Using global chunk_start would cause
        # incorrect causal mask (Q would see K blocks it shouldn't).
        attn_current = flat_group_gemm_fuse_reshape(
            Q, K_current, self.stride,
            chunk_start=0,  # Local: Q starts at 0 relative to K_current
            chunk_end=q_reshaped_len,  # Local: Q ends at q_reshaped_len
            is_causal=True,  # Current chunk: apply causal mask
        )
        attn_scores_list.append(attn_current)
        del K_current

    # ============================================================
    # Step 4: Concatenate all attn_scores
    # ============================================================
    if not attn_scores_list:
        return available_blocks
    attn_scores = torch.cat(attn_scores_list, dim=-1)
    del attn_scores_list
    # Calculate padded K length for later use
    padded_k_len = historical_k_len + padded_curr_k_len

    # ============================================================
    # Step 5: Apply softmax_fuse_block_sum with causal=True
    # ============================================================
    cpu_block_size = block_size  # e.g., 4096
    bsa_per_cpu = cpu_block_size // self.BSA_BLOCK_SIZE  # e.g., 4096/128 = 32
    # Use BSA_BLOCK_SIZE for block aggregation (aligned with GPU-only)
    reshaped_bsa_bs = self.BSA_BLOCK_SIZE // self.stride  # e.g., 128/8 = 16
    norm = 1.0
    # 1.4426950408889634 == log2(e)
    scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm
    segment_size = min(4096, reshaped_bsa_bs)
    with nvtx.range("xattn_estimate_softmax"):
        block_sums = softmax_fuse_block_sum(
            attn_scores,
            reshaped_bsa_bs,
            segment_size,
            chunk_start=chunk_start,
            chunk_end=chunk_end,
            real_q_len=real_q_len,
            scale=scale,
            is_causal=True,  # Causal for consistent with GPU-only
        )
    # block_sums shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]

    # ============================================================
    # Step 6: Use find_blocks_chunked to generate BSA-level mask
    # ============================================================
    # Calculate BSA block indices
    q_bsa_blocks = (padded_q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
    total_k_bsa_blocks = (padded_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
    historical_k_bsa_blocks = num_historical_blocks * bsa_per_cpu
    # current_index for find_blocks_chunked: Q's block offset
    q_start_bsa_block = historical_k_bsa_blocks  # Q starts after historical K
    with nvtx.range("xattn_find_blocks"):
        # For selecting historical K use causal=False, because all
        # historical K precede the current Q.
        # current_index=0 avoids running past the K dimension of block_sums.
        mask = find_blocks_chunked(
            block_sums,
            current_index=0,
            threshold=self.threshold,
            num_to_choose=None,
            decoding=False,
            mode="both",
            causal=False,
        )
    # mask shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]

    # ============================================================
    # Step 7: Extract mask portions and record density
    # ============================================================
    B, H, Q_bsa, K_bsa_total = mask.shape
    # Calculate valid Q blocks (excluding padding)
    valid_q_bsa = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
    valid_curr_k_bsa = (curr_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
    # 7a: Record historical blocks density (temporarily disabled; the DEBUG output is used instead)
    # if historical_k_bsa_blocks > 0:
    #     ... DensityObserver.record_counts ...
    # 7b: Record current chunk density (temporarily disabled)
    # if valid_curr_k_bsa > 0:
    #     ... DensityObserver.record_counts ...
    # Step 7.5: Save historical mask to pre-allocated buffer for compute_chunked_prefill
    # Use full Q_bsa (padded) for buffer, not valid_q_bsa
    mask_historical_full = mask[:, :, :, :historical_k_bsa_blocks]
    if self._prefill_mask_buffer is not None:
        # Only save historical portion of mask
        self._prefill_mask_buffer[:, :, :Q_bsa, :historical_k_bsa_blocks].copy_(mask_historical_full)
        self._current_mask_q_bsa = Q_bsa
        self._current_mask_k_bsa = historical_k_bsa_blocks

    # ============================================================
    # Step 8: Aggregate mask to CPU block level (union of heads)
    # ============================================================
    # Only aggregate historical blocks (current chunk is always full attention)
    num_cpu_blocks = num_historical_blocks
    with nvtx.range("xattn_aggregate_mask"):
        # Reshape historical mask: [B, H, Q_bsa, historical_k_bsa] -> [B, H, Q_bsa, num_cpu, bsa_per_cpu]
        # Use full Q_bsa (not valid_q_bsa) for aggregation
        mask_per_cpu = mask_historical_full.view(B, H, Q_bsa, num_cpu_blocks, bsa_per_cpu)
        # Union across: bsa_per_cpu, Q_bsa, heads -> [B, num_cpu]
        cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1)  # [B, num_cpu]
        # Get selected indices
        selected_indices = cpu_needed[0].nonzero().squeeze(-1).tolist()
        if isinstance(selected_indices, int):
            selected_indices = [selected_indices]
        # Handle empty available_blocks case (first chunk)
        if available_blocks:
            selected_block_ids = [available_blocks[i] for i in selected_indices]
        else:
            selected_block_ids = []
    # Always include first block (sink) and last block for safety
    if available_blocks and available_blocks[0] not in selected_block_ids:
        selected_block_ids.insert(0, available_blocks[0])
    if available_blocks and available_blocks[-1] not in selected_block_ids:
        selected_block_ids.append(available_blocks[-1])
    # Record communication density (CPU block granularity) - only if there are historical blocks
    if available_blocks:
        DensityObserver.record_comm_density(
            layer_id,
            selected_cpu_blocks=len(selected_block_ids),
            total_cpu_blocks=len(available_blocks),
        )
    # Update statistics (only for layer 0 to avoid overcounting)
    if layer_id == 0 and available_blocks:
        self._stats_total_available_blocks += len(available_blocks)
        self._stats_total_selected_blocks += len(selected_block_ids)
        self._stats_num_chunks += 1
        # Log per-chunk density
        chunk_density = len(selected_block_ids) / len(available_blocks)
        logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, available={len(available_blocks)}, "
                     f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}")
    # Free intermediate tensors to prevent memory leak
    del attn_scores, block_sums, mask, mask_historical_full
    return selected_block_ids
def compute_chunked_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
offload_engine: "OffloadEngine",
kvcache_manager: "KVCacheManager",
current_chunk_idx: int,
seq: "Sequence",
num_tokens: int,
selected_blocks: List[int],
) -> torch.Tensor:
"""
Compute chunked-prefill attention over the pre-selected historical blocks
and the current chunk, then merge the partial results.

This method handles the chunked prefill computation:
1. For each entry of ``selected_blocks``: load it into a ring slot through
   the offload engine, compute non-causal attention against it, and fold
   the partial result into a running (output, LSE) accumulator.
2. Compute causal attention against the current chunk's K/V, which is read
   from the offload engine's prefill buffer.
3. Merge the historical accumulator (if any) with the current-chunk result.

Note: The BSA-level mask is saved in self._prefill_mask_buffer by select_blocks().
Currently we use flash_attn_with_lse for computation (supports LSE merge).
TODO: Optimize to use BSA kernel with the saved mask for per-head sparse attention.

Args:
    q: Query tensor [seq_len, num_heads, head_dim]
    k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
    v: Value tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
    layer_id: Current layer index
    softmax_scale: Softmax scaling factor
    offload_engine: OffloadEngine for loading blocks
    kvcache_manager: KVCacheManager for block management (not referenced in this method)
    current_chunk_idx: Current chunk index (not referenced in this method)
    seq: Sequence object (not referenced in this method)
    num_tokens: Number of tokens in current chunk
    selected_blocks: List of CPU block IDs selected by select_blocks

Returns:
    Attention output [seq_len, num_heads, head_dim]
"""
# Use FlashInfer-based implementations (more optimized)
# Imported inside the method and aliased to the generic names used below.
from nanovllm.ops.chunked_attention import (
flash_attn_with_lse_flashinfer as flash_attn_with_lse,
merge_attention_outputs_flashinfer as merge_attention_outputs,
)
q_batched = q.unsqueeze(0) # [1, seq_len, num_heads, head_dim]
# Running accumulator for the historical blocks; stays None when
# selected_blocks is empty, in which case only the current chunk contributes.
o_acc = None
lse_acc = None
compute_stream = offload_engine.compute_stream
# Use the pre-selected blocks directly
cpu_block_table = selected_blocks
# Note: BSA mask is available in self._prefill_mask_buffer (saved by select_blocks)
# Mask shape: [1, num_heads, Q_bsa, K_bsa] where Q_bsa = self._current_mask_q_bsa
# Selected indices: self._selected_cpu_indices, bsa_per_cpu: self._bsa_per_cpu
# TODO: Use this mask with BSA kernel for per-head sparse attention optimization
if cpu_block_table:
with nvtx.range("xattn_compute_historical"):
# Every ring slot index reported by the offload engine is used for loading.
load_slots = list(range(offload_engine.num_ring_slots))
num_blocks = len(cpu_block_table)
if len(load_slots) == 1:
# Only 1 slot - use synchronous mode
# Each block is loaded, waited for, and consumed before the next load is issued.
slot = load_slots[0]
for block_idx in range(num_blocks):
cpu_block_id = cpu_block_table[block_idx]
# The CPU block id is also passed as chunk_idx; its meaning is defined by
# OffloadEngine.load_to_slot_layer (not visible here) - TODO confirm.
offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
# Non-causal attention: presumably every historical block precedes all
# query tokens of the current chunk, so no mask is needed - verify against caller.
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=softmax_scale,
causal=False,
)
# First block seeds the accumulator; later blocks are merged into it.
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Presumably signals that the slot's contents have been consumed so a later
# load may overwrite it - confirm against OffloadEngine.
offload_engine.record_slot_compute_done(slot)
else:
# Multiple slots - use pipeline
num_slots = len(load_slots)
# Prime the pipeline: issue loads for the first min(num_slots, num_blocks)
# blocks, block i going to slot i.
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
cpu_block_id = cpu_block_table[i]
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_id, chunk_idx=cpu_block_id)
for block_idx in range(num_blocks):
# Blocks are assigned to slots round-robin.
current_slot = load_slots[block_idx % num_slots]
offload_engine.wait_slot_layer(current_slot)
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=softmax_scale,
causal=False,
)
# Unlike the single-slot path, compute-done is recorded before the merge here.
offload_engine.record_slot_compute_done(current_slot)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Issue next transfer
# (block_idx + num_slots) % num_slots == block_idx % num_slots, so the next
# load targets the slot that was just consumed.
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
next_slot = load_slots[next_block_idx % num_slots]
next_cpu_block_id = cpu_block_table[next_block_idx]
offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id, chunk_idx=next_cpu_block_id)
# Compute attention to current chunk (causal mask)
with nvtx.range("xattn_compute_current"):
with torch.cuda.stream(compute_stream):
# K/V of the current chunk come from the offload engine's prefill buffer,
# not from the k/v arguments of this method.
k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
current_o, current_lse = flash_attn_with_lse(
q_batched, k_curr, v_curr,
softmax_scale=softmax_scale,
causal=True,
)
# Merge historical and current attention
with nvtx.range("xattn_compute_merge"):
with torch.cuda.stream(compute_stream):
# No historical blocks were processed: the current-chunk result is final.
if o_acc is None:
final_o = current_o
else:
# The merged LSE is discarded; only the output tensor is returned.
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
# Sync default stream with compute_stream before returning
torch.cuda.default_stream().wait_stream(compute_stream)
# Remove batch dimension: [1, seq_len, num_heads, head_dim] -> [seq_len, num_heads, head_dim]
return final_o.squeeze(0)
def compute_chunked_decode(
    self,
    q: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
    offload_engine: "OffloadEngine",
    kvcache_manager: "KVCacheManager",
    seq: "Sequence",
    selected_blocks: List[int],
) -> torch.Tensor:
    """
    Reject decode-phase attention: this policy does not implement it.

    None of the arguments are used; the call fails unconditionally.

    Raises:
        NotImplementedError: Always.
    """
    message = (
        "XAttentionBSAPolicy does not support decode phase. "
        "Use FullAttentionPolicy for decode."
    )
    raise NotImplementedError(message)
def reset(self) -> None:
    """Clear the sparse metadata held by the policy."""
    # Density statistics are deliberately left alone: they accumulate
    # across the entire prefill and are cleared only by reset_stats().
    metadata = self.sparse_metadata
    metadata.clear()
def reset_stats(self) -> None:
    """Zero the three accumulated density counters."""
    (
        self._stats_total_available_blocks,
        self._stats_total_selected_blocks,
        self._stats_num_chunks,
    ) = (0, 0, 0)
def get_density_stats(self) -> dict:
    """
    Return the accumulated density counters as a dictionary.

    Returns:
        A dict with keys "total_available_blocks", "total_selected_blocks",
        "num_chunks" and "overall_density" (selected / available). When no
        available blocks have been counted, every entry is zero.
    """
    available = self._stats_total_available_blocks
    # Start from the all-zero report; it is returned as-is when nothing
    # has been counted, which also avoids dividing by zero.
    stats = {
        "total_available_blocks": 0,
        "total_selected_blocks": 0,
        "num_chunks": 0,
        "overall_density": 0.0,
    }
    if available != 0:
        selected = self._stats_total_selected_blocks
        stats.update(
            total_available_blocks=available,
            total_selected_blocks=selected,
            num_chunks=self._stats_num_chunks,
            overall_density=selected / available,
        )
    return stats
def print_density_stats(self) -> None:
    """Log a one-line summary of the accumulated density statistics at INFO level."""
    stats = self.get_density_stats()
    summary = (
        f"[XAttn BSA] Density Stats: chunks={stats['num_chunks']}, "
        f"available={stats['total_available_blocks']}, "
        f"selected={stats['total_selected_blocks']}, "
        f"density={stats['overall_density']:.1%}"
    )
    logger.info(summary)
def __repr__(self) -> str:
return f"XAttentionBSAPolicy(threshold={self.threshold}, stride={self.stride})"