WIP: Enhance sparse attention with density tracking and block selection improvements

- Added analysis documentation for xattn density alignment.
- Refactored ModelRunner to pre-allocate policy metadata buffers regardless of CPU offload configuration.
- Updated FullAttentionPolicy and SparsePolicy to accept query and key tensors for block selection.
- Enhanced QuestPolicy to utilize query tensor for block selection and improved handling of selected blocks.
- Expanded XAttentionBSAPolicy to support chunked prefill and improved attention score computation with historical and current chunk handling.
- Introduced DensityObserver to track compute and communication density for sparse attention layers.
- Updated attention layer to ensure block selection is always called, improving robustness in first chunk scenarios.
- Added tests for attention kernel behavior with enhanced input patterns.
This commit is contained in:
Zijie Tian
2026-01-31 14:48:23 +08:00
parent f6ac4ccdde
commit 2e96d1d97d
9 changed files with 490 additions and 152 deletions

View File

@@ -37,6 +37,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
| [`docs/estimate_block_size_performance.md`](docs/estimate_block_size_performance.md) | 🔥 PERF: estimate 阶段 block_size 性能分析softmax_fuse_block_sum 最优点 (512-1024),当前 4096 慢 15x | | [`docs/estimate_block_size_performance.md`](docs/estimate_block_size_performance.md) | 🔥 PERF: estimate 阶段 block_size 性能分析softmax_fuse_block_sum 最优点 (512-1024),当前 4096 慢 15x |
| [`docs/long_context_models_1m.md`](docs/long_context_models_1m.md) | 📚 REF: 1M+ 上下文长度模型列表 (Qwen/GLM/InternLM/Llama/VL)≤10B 推荐模型 | | [`docs/long_context_models_1m.md`](docs/long_context_models_1m.md) | 📚 REF: 1M+ 上下文长度模型列表 (Qwen/GLM/InternLM/Llama/VL)≤10B 推荐模型 |
| [`docs/new_model_integration_guide.md`](docs/new_model_integration_guide.md) | 🔧 GUIDE: 新模型整合指南 - 配置映射、RoPE变体、EOS处理、权重转换、验证清单 | | [`docs/new_model_integration_guide.md`](docs/new_model_integration_guide.md) | 🔧 GUIDE: 新模型整合指南 - 配置映射、RoPE变体、EOS处理、权重转换、验证清单 |
| [`docs/xattn_density_alignment_analysis.md`](docs/xattn_density_alignment_analysis.md) | 📊 ANALYSIS: GPU-only vs Offload 模式 density 对齐分析,chunked softmax 边界效应,5-7% 差异根因 |
## Rules Index ## Rules Index

View File

@@ -229,7 +229,7 @@ class ModelRunner:
# GPU-only mode: pre-allocate policy metadata buffers # GPU-only mode: pre-allocate policy metadata buffers
# This avoids dynamic GPU memory allocation during forward pass # This avoids dynamic GPU memory allocation during forward pass
if not config.enable_cpu_offload: # if not config.enable_cpu_offload:
num_heads = hf_config.num_attention_heads // self.world_size num_heads = hf_config.num_attention_heads // self.world_size
self.kvcache_manager.sparse_policy.alloc_policy_metadata( self.kvcache_manager.sparse_policy.alloc_policy_metadata(
num_heads=num_heads, num_heads=num_heads,

View File

@@ -47,6 +47,8 @@ class FullAttentionPolicy(SparsePolicy):
available_blocks: List[int], available_blocks: List[int],
offload_engine: "OffloadEngine", offload_engine: "OffloadEngine",
ctx: PolicyContext, ctx: PolicyContext,
q: torch.Tensor,
k: torch.Tensor,
) -> List[int]: ) -> List[int]:
"""Return all blocks - no sparsity.""" """Return all blocks - no sparsity."""
# Update statistics (only for layer 0 to avoid overcounting) # Update statistics (only for layer 0 to avoid overcounting)

View File

@@ -142,6 +142,8 @@ class SparsePolicy(ABC):
available_blocks: List[int], available_blocks: List[int],
offload_engine: "OffloadEngine", offload_engine: "OffloadEngine",
ctx: PolicyContext, ctx: PolicyContext,
q: torch.Tensor,
k: torch.Tensor,
) -> List[int]: ) -> List[int]:
""" """
Select which KV blocks to load for the current query chunk. Select which KV blocks to load for the current query chunk.
@@ -158,6 +160,8 @@ class SparsePolicy(ABC):
to load KV to make selection decisions). to load KV to make selection decisions).
ctx: PolicyContext with information about the current query ctx: PolicyContext with information about the current query
chunk, layer, phase (prefill/decode), etc. chunk, layer, phase (prefill/decode), etc.
q: Query tensor [seq_len, num_heads, head_dim] for current chunk
k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk
Returns: Returns:
List of block IDs to load (must be a subset of available_blocks). List of block IDs to load (must be a subset of available_blocks).

View File

@@ -191,13 +191,26 @@ class QuestPolicy(SparsePolicy):
def select_blocks( def select_blocks(
self, self,
available_blocks: List[int], available_blocks: List[int],
offload_engine: "OffloadEngine",
ctx: PolicyContext, ctx: PolicyContext,
q: torch.Tensor,
k: torch.Tensor,
) -> List[int]: ) -> List[int]:
""" """
Select Top-K blocks based on query-key similarity bounds. Select Top-K blocks based on query-key similarity bounds.
If query is not available (some prefill scenarios), falls back If query is not available (some prefill scenarios), falls back
to loading all blocks. to loading all blocks.
Args:
available_blocks: List of CPU block IDs
offload_engine: OffloadEngine for loading KV (unused in Quest)
ctx: PolicyContext with metadata
q: Query tensor [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim] (unused in Quest, uses metadata instead)
Returns:
Selected block IDs
""" """
if self.metadata is None: if self.metadata is None:
raise RuntimeError( raise RuntimeError(
@@ -211,7 +224,7 @@ class QuestPolicy(SparsePolicy):
if n <= self.config.threshold_blocks: if n <= self.config.threshold_blocks:
return available_blocks return available_blocks
if ctx.query is None: if q is None:
# No query available - cannot compute scores # No query available - cannot compute scores
return available_blocks return available_blocks
@@ -221,11 +234,10 @@ class QuestPolicy(SparsePolicy):
) )
# Metadata is already on GPU, same device as query # Metadata is already on GPU, same device as query
device = ctx.query.device device = q.device
# Compute upper bound scores # Compute upper bound scores
# query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim] # query shape: [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
q = ctx.query
if q.dim() == 4: if q.dim() == 4:
# Prefill: use mean over sequence length # Prefill: use mean over sequence length
q = q.mean(dim=1) # [1, num_heads, head_dim] q = q.mean(dim=1) # [1, num_heads, head_dim]

View File

@@ -135,6 +135,21 @@ class XAttentionBSAPolicy(SparsePolicy):
self._v_expanded: torch.Tensor | None = None self._v_expanded: torch.Tensor | None = None
self._max_seq_len: int = 0 self._max_seq_len: int = 0
# Pre-allocated mask buffer for chunked prefill (offload mode)
# Stores BSA-level mask from select_blocks for use in compute_chunked_prefill
# Shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks]
self._prefill_mask_buffer: torch.Tensor | None = None
self._current_mask_q_bsa: int = 0 # Current Q BSA blocks in buffer
self._current_mask_k_bsa: int = 0 # Current K BSA blocks in buffer
# Selected block indices for mask extraction in compute_chunked_prefill
# Stores the indices of selected CPU blocks in available_blocks
self._selected_cpu_indices: List[int] = []
self._bsa_per_cpu: int = 0 # BSA blocks per CPU block
#> Debug: store all K cache
self._debug_k_full: torch.Tensor | None = None
def alloc_policy_metadata( def alloc_policy_metadata(
self, self,
num_heads: int, num_heads: int,
@@ -162,7 +177,17 @@ class XAttentionBSAPolicy(SparsePolicy):
dtype: Data type dtype: Data type
device: Target device device: Target device
""" """
# Only allocate if GQA (num_heads != num_kv_heads) # Pre-allocate mask buffer for chunked prefill (offload mode)
# mask shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks]
# This is needed regardless of GQA
max_q_bsa_blocks = self.chunk_size // self.BSA_BLOCK_SIZE
max_k_bsa_blocks = max_seq_len // self.BSA_BLOCK_SIZE
mask_shape = (1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks)
self._prefill_mask_buffer = torch.empty(mask_shape, dtype=torch.bool, device=device)
mask_memory_mb = num_heads * max_q_bsa_blocks * max_k_bsa_blocks / (1024 * 1024)
logger.info(f"[XAttn] Pre-allocated mask buffer: shape={mask_shape}, memory={mask_memory_mb:.1f} MB")
# Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads)
if num_heads == num_kv_heads: if num_heads == num_kv_heads:
logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})") logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
return return
@@ -177,6 +202,9 @@ class XAttentionBSAPolicy(SparsePolicy):
memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024) memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB") logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")
#DEBUG : buffer for save all K cache.
self._debug_k_full = torch.empty((1, num_heads, max_seq_len, head_dim), dtype=dtype, device=device)
# ========================================================================= # =========================================================================
# GPU-only methods (non-chunked) # GPU-only methods (non-chunked)
# ========================================================================= # =========================================================================
@@ -401,33 +429,42 @@ class XAttentionBSAPolicy(SparsePolicy):
available_blocks: List[int], available_blocks: List[int],
offload_engine: "OffloadEngine", offload_engine: "OffloadEngine",
ctx: PolicyContext, ctx: PolicyContext,
q: torch.Tensor,
k: torch.Tensor,
) -> List[int]: ) -> List[int]:
""" """
Compute attention scores for all available blocks using flat_group_gemm, Compute attention scores for all available blocks using flat_group_gemm,
then use softmax_fuse_block_sum and find_blocks_chunked to select important blocks. then use softmax_fuse_block_sum and find_blocks_chunked to select important blocks.
This method: This method aligns with GPU-only xattn_estimate_chunked:
1. Loads each K block from CPU 1. Loads each K block from CPU (historical blocks)
2. Computes Q@K^T attention scores using XAttention stride reshape 2. Gets current chunk K from prefill buffer
3. Applies softmax_fuse_block_sum to get block-level attention 3. Concatenates [historical K, current chunk K] for correct softmax normalization
4. Uses find_blocks_chunked to select blocks based on threshold 4. Uses causal=True with correct chunk_start for position-aware masking
5. Only selects from historical blocks (current chunk is always full attention)
Args: Args:
available_blocks: List of CPU block IDs available_blocks: List of CPU block IDs (historical blocks only)
offload_engine: OffloadEngine for loading blocks offload_engine: OffloadEngine for loading blocks
ctx: PolicyContext with query tensor and metadata ctx: PolicyContext with metadata
q: Query tensor [seq_len, num_heads, head_dim] for current chunk
k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk (used for estimation)
Returns: Returns:
Selected block IDs based on attention threshold Selected block IDs based on attention threshold
""" """
if not available_blocks or ctx.query is None: if q is None:
return available_blocks return available_blocks
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum, find_blocks_chunked from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum, find_blocks_chunked
import math import math
layer_id = ctx.layer_id layer_id = ctx.layer_id
q = ctx.query # [seq_len, num_heads, head_dim] # Use passed q parameter instead of ctx.query
# Set DensityObserver mode on first layer
if layer_id == 0:
DensityObserver.set_mode("offload")
# Convert Q to [batch, heads, seq_len, head_dim] # Convert Q to [batch, heads, seq_len, head_dim]
# q: [seq_len, num_heads, head_dim] -> [1, num_heads, seq_len, head_dim] # q: [seq_len, num_heads, head_dim] -> [1, num_heads, seq_len, head_dim]
@@ -454,18 +491,37 @@ class XAttentionBSAPolicy(SparsePolicy):
q_reshaped_len = padded_q_len // self.stride q_reshaped_len = padded_q_len // self.stride
# Use a single slot for loading (synchronous mode for simplicity) # Get block size from context
block_size = ctx.block_size # tokens per CPU block (e.g., 4096)
reshaped_block_size = block_size // self.stride # e.g., 4096/8 = 512
# ============================================================
# Step 1: Compute chunk_start and related parameters
# ============================================================
# chunk_start = Q's global position in reshaped space
# Q starts at position: num_historical_blocks * block_size
num_historical_blocks = len(available_blocks)
historical_k_len = num_historical_blocks * block_size
chunk_start = historical_k_len // self.stride # Q's position in reshaped space
chunk_end = chunk_start + q_reshaped_len
# For valid Q length tracking (excluding padding)
valid_q_reshaped = (q_len + self.stride - 1) // self.stride
real_q_len = chunk_start + valid_q_reshaped
# ============================================================
# Step 2: Pipeline load historical K blocks and compute attn_scores
# ============================================================
# Key design: Load each block, compute immediately, then release
# This avoids storing all K in GPU memory at once (offload-friendly)
slot = 0 slot = 0
attn_scores_list = [] attn_scores_list = []
BLOCK_N = 128
k_alignment = self.stride * BLOCK_N
# Get block size from context with nvtx.range("xattn_estimate_historical"):
block_size = ctx.block_size # tokens per CPU block (e.g., 1024)
reshaped_block_size = block_size // self.stride # e.g., 1024/8 = 128
with nvtx.range("xattn_estimate_gemm"):
for cpu_block_id in available_blocks: for cpu_block_id in available_blocks:
# Load only K from CPU to GPU (V not needed for estimate) # Load only K from CPU to GPU (V not needed for estimate)
# This saves 50% communication in the estimate phase
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot) offload_engine.wait_slot_layer(slot)
@@ -473,8 +529,7 @@ class XAttentionBSAPolicy(SparsePolicy):
k_block = offload_engine.get_k_for_slot(slot) k_block = offload_engine.get_k_for_slot(slot)
# Convert K to [batch, heads, k_len, head_dim] # Convert K to [batch, heads, k_len, head_dim]
# k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim] K_chunk = k_block.transpose(1, 2) # [1, num_kv_heads, block_size, head_dim]
K_chunk = k_block.transpose(1, 2)
# Handle GQA: expand K heads to match Q heads # Handle GQA: expand K heads to match Q heads
num_kv_heads = K_chunk.shape[1] num_kv_heads = K_chunk.shape[1]
@@ -482,116 +537,220 @@ class XAttentionBSAPolicy(SparsePolicy):
num_groups = num_heads // num_kv_heads num_groups = num_heads // num_kv_heads
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1) K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
# Pad K if necessary (k_len must be divisible by stride * BLOCK_N) #> DEBUG: save all K cache
k_len = K_chunk.shape[2] start_pos = cpu_block_id * block_size
BLOCK_N = 128 self._debug_k_full[:, :, start_pos:start_pos + block_size, :].copy_(K_chunk)
k_alignment = self.stride * BLOCK_N
if k_len < k_alignment:
# K too short, pad it
pad_size = k_alignment - k_len
K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
# Compute attention scores using flat_group_gemm_fuse_reshape # # Pad K if necessary
# Output: [batch, heads, q_len/stride, k_len/stride] # k_len = K_chunk.shape[2]
attn_chunk = flat_group_gemm_fuse_reshape( # if k_len < k_alignment:
Q, K_chunk, self.stride, # pad_size = k_alignment - k_len
chunk_start=0, # K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
chunk_end=q_reshaped_len,
is_causal=False # # Compute attention scores for this historical block
) # # Historical blocks: all positions < Q, so Q always sees them (full attention)
attn_scores_list.append(attn_chunk) # # Use LOCAL chunk_start=0 to match test_xattn_k_chunked.py behavior
# attn_chunk = flat_group_gemm_fuse_reshape(
# Q, K_chunk, self.stride,
# chunk_start=0, # Local: same as test
# chunk_end=q_reshaped_len,
# is_causal=False, # Historical K: all visible to Q
# )
# attn_scores_list.append(attn_chunk)
# Mark slot as done for reuse # Mark slot as done for reuse
offload_engine.record_slot_compute_done(slot) offload_engine.record_slot_compute_done(slot)
# Concatenate all attention scores along K dimension num_kv_heads = k.shape[1]
# Each chunk: [1, heads, q_reshaped_len, block_reshaped_len] if num_heads != num_kv_heads:
# Result: [1, heads, q_reshaped_len, total_k_reshaped_len] num_groups = num_heads // num_kv_heads
k_repeated = k.repeat_interleave(num_groups, dim=1).unsqueeze(0).transpose(1, 2) # [1, num_heads, historical_k_len, head_dim]
self._debug_k_full[:, :, historical_k_len:historical_k_len + q_len, :].copy_(k_repeated)
if layer_id == 0:
__import__('pdb').set_trace()
# ============================================================
# Step 3: Get current chunk K and compute its attn_scores
# ============================================================
with nvtx.range("xattn_estimate_current"):
# Current chunk K is in prefill buffer (already on GPU)
k_curr, _ = offload_engine.get_prefill_buffer_slice(layer_id, q_len)
# k_curr: [1, q_len, num_kv_heads, head_dim] -> [1, num_kv_heads, q_len, head_dim]
K_current = k_curr.transpose(1, 2)
# Handle GQA for current chunk K
num_kv_heads = K_current.shape[1]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
K_current = K_current.repeat_interleave(num_groups, dim=1)
# Pad current K if necessary
curr_k_len = K_current.shape[2]
padded_curr_k_len = ((curr_k_len + k_alignment - 1) // k_alignment) * k_alignment
if padded_curr_k_len != curr_k_len:
pad_size = padded_curr_k_len - curr_k_len
K_current = torch.nn.functional.pad(K_current, (0, 0, 0, pad_size), value=0)
# Compute attention scores for current chunk
# IMPORTANT: Use LOCAL coordinates (0 to q_reshaped_len) for current chunk!
# Because K_current only contains current chunk K (not full sequence),
# block_n in kernel starts from 0. Using global chunk_start would cause
# incorrect causal mask (Q would see K blocks it shouldn't).
attn_current = flat_group_gemm_fuse_reshape(
Q, K_current, self.stride,
chunk_start=0, # Local: Q starts at 0 relative to K_current
chunk_end=q_reshaped_len, # Local: Q ends at q_reshaped_len
is_causal=True, # Current chunk: apply causal mask
)
attn_scores_list.append(attn_current)
del K_current
# ============================================================
# Step 4: Concatenate all attn_scores
# ============================================================
if not attn_scores_list: if not attn_scores_list:
return available_blocks return available_blocks
attn_scores = torch.cat(attn_scores_list, dim=-1) attn_scores = torch.cat(attn_scores_list, dim=-1)
# Free intermediate list immediately
del attn_scores_list del attn_scores_list
# Step 2: Apply softmax_fuse_block_sum with hierarchical aggregation # Calculate padded K length for later use
# Use smaller estimate_block_size (1024) for 15x faster softmax kernel, padded_k_len = historical_k_len + padded_curr_k_len
# then aggregate to CPU block level (4096).
#
# Hierarchical approach:
# 1. softmax_fuse_block_sum with estimate_block_size (1024) -> fine-grained scores
# 2. Aggregate: reshape + sum -> CPU block level scores
# 3. Select blocks based on score + threshold (NOT mask + voting)
cpu_block_size = block_size # e.g., 4096
estimate_bs = self.estimate_block_size # e.g., 1024 (15x faster)
ratio = cpu_block_size // estimate_bs # e.g., 4
# Use estimate_block_size for softmax kernel (optimized) # ============================================================
reshaped_est_bs = estimate_bs // self.stride # e.g., 1024/8 = 128 # Step 5: Apply softmax_fuse_block_sum with causal=True
norm = 1.0 # Normalization factor # ============================================================
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm # log2(e) with scaling cpu_block_size = block_size # e.g., 4096
segment_size = min(4096, reshaped_est_bs) bsa_per_cpu = cpu_block_size // self.BSA_BLOCK_SIZE # e.g., 4096/128 = 32
# Use BSA_BLOCK_SIZE for block aggregation (aligned with GPU-only)
reshaped_bsa_bs = self.BSA_BLOCK_SIZE // self.stride # e.g., 128/8 = 16
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm
segment_size = min(4096, reshaped_bsa_bs)
with nvtx.range("xattn_estimate_softmax"): with nvtx.range("xattn_estimate_softmax"):
block_sums_fine = softmax_fuse_block_sum( block_sums = softmax_fuse_block_sum(
attn_scores, attn_scores,
reshaped_est_bs, # Use optimized estimate block size (128 vs 512) reshaped_bsa_bs,
segment_size, segment_size,
chunk_start=0, chunk_start=chunk_start,
chunk_end=q_reshaped_len, chunk_end=chunk_end,
real_q_len=q_reshaped_len, real_q_len=real_q_len,
scale=scale, scale=scale,
is_causal=False, # Historical blocks are all before current chunk is_causal=True, # Causal for consistent with GPU-only
) )
# block_sums_fine shape: [batch, heads, q_est_blocks, k_est_blocks] # block_sums shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]
# where k_est_blocks = len(available_blocks) * ratio
# Step 3: Aggregate to CPU block level (hierarchical sum) # ============================================================
# This is mathematically equivalent to direct computation but much faster # Step 6: Use find_blocks_chunked to generate BSA-level mask
batch_size_bs, num_heads_bs, q_est_blocks, k_est_blocks = block_sums_fine.shape # ============================================================
num_cpu_blocks = len(available_blocks) # Calculate BSA block indices
q_bsa_blocks = (padded_q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
total_k_bsa_blocks = (padded_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
historical_k_bsa_blocks = num_historical_blocks * bsa_per_cpu
with nvtx.range("xattn_estimate_aggregate"): # current_index for find_blocks_chunked: Q's block offset
# Reshape: [batch, heads, q_est, k_est] -> [batch, heads, q_est, num_cpu, ratio] q_start_bsa_block = historical_k_bsa_blocks # Q starts after historical K
block_sums_coarse = block_sums_fine.view(
batch_size_bs, num_heads_bs, q_est_blocks, num_cpu_blocks, ratio
).sum(dim=-1) # [batch, heads, q_est_blocks, num_cpu_blocks]
# Sum over Q dimension to get total attention from Q chunk to each K block with nvtx.range("xattn_find_blocks"):
cpu_block_scores = block_sums_coarse.sum(dim=2) # [batch, heads, num_cpu_blocks] mask = find_blocks_chunked(
block_sums,
current_index=q_start_bsa_block, # Q's position in BSA blocks
threshold=self.threshold,
num_to_choose=None,
decoding=False,
mode="prefill",
causal=True, # Causal for block-level mask
)
# mask shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]
# Step 4: Select blocks using score + threshold (replaces mask + majority voting) # ============================================================
# This is simpler and more direct than the original mask-based approach # Step 7: Extract mask portions and record density
with nvtx.range("xattn_estimate_select"): # ============================================================
# Average scores across heads (GQA-aware: all heads contribute equally) B, H, Q_bsa, K_bsa_total = mask.shape
scores_per_block = cpu_block_scores.mean(dim=(0, 1)) # [num_cpu_blocks]
# Normalize to get attention distribution # Calculate valid Q blocks (excluding padding)
total_score = scores_per_block.sum() valid_q_bsa = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
if total_score > 0: valid_curr_k_bsa = (curr_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
score_ratio = scores_per_block / total_score
# 7a: Record historical blocks density
# IMPORTANT: For historical blocks, apply causal mask to match GPU-only density calculation!
# Q block i (global position = q_start_bsa_block + i) can see historical K block j
# only if j <= q_start_bsa_block + i (causal constraint)
mask_historical = mask[:, :, :valid_q_bsa, :historical_k_bsa_blocks]
if historical_k_bsa_blocks > 0:
# Create causal mask for historical blocks
# Q_global[i] = q_start_bsa_block + i, K[j] = j
# Causal: j <= Q_global[i] => j <= q_start_bsa_block + i
q_global_indices = torch.arange(valid_q_bsa, device=mask.device) + q_start_bsa_block
k_indices = torch.arange(historical_k_bsa_blocks, device=mask.device)
# Q at position q_global_indices[i] can see K at position k_indices[j] if k_indices[j] <= q_global_indices[i]
causal_mask_historical = k_indices.unsqueeze(0) <= q_global_indices.unsqueeze(1) # [valid_q_bsa, historical_k_bsa_blocks]
# Count positions within causal mask only
total_historical_causal = causal_mask_historical.sum().item() * B * H
selected_historical = (mask_historical & causal_mask_historical.unsqueeze(0).unsqueeze(0)).sum().item()
if total_historical_causal > 0:
DensityObserver.record_counts(layer_id, selected_historical, total_historical_causal)
# 7b: Record current chunk density (causal, to align with GPU-only mode)
# Current chunk is the portion after historical blocks
if valid_curr_k_bsa > 0:
# Extract current chunk mask (only valid portion, not padded)
mask_current = mask[:, :, :valid_q_bsa, historical_k_bsa_blocks:historical_k_bsa_blocks + valid_curr_k_bsa]
q_dim = mask_current.shape[2]
k_dim = mask_current.shape[3]
# Create causal mask (lower triangular)
# For current chunk: Q[i] can see K[j] where j <= i (standard causal)
causal_mask = torch.tril(torch.ones(q_dim, k_dim, device=mask.device, dtype=torch.bool))
# Count positions within causal mask only
total_current_causal = causal_mask.sum().item() * B * H
selected_current = (mask_current & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
if total_current_causal > 0:
DensityObserver.record_counts(layer_id, selected_current, total_current_causal)
# Step 7.5: Save historical mask to pre-allocated buffer for compute_chunked_prefill
# Use full Q_bsa (padded) for buffer, not valid_q_bsa
mask_historical_full = mask[:, :, :, :historical_k_bsa_blocks]
if self._prefill_mask_buffer is not None:
# Only save historical portion of mask
self._prefill_mask_buffer[:, :, :Q_bsa, :historical_k_bsa_blocks].copy_(mask_historical_full)
self._current_mask_q_bsa = Q_bsa
self._current_mask_k_bsa = historical_k_bsa_blocks
# ============================================================
# Step 8: Aggregate mask to CPU block level (union of heads)
# ============================================================
# Only aggregate historical blocks (current chunk is always full attention)
num_cpu_blocks = num_historical_blocks
with nvtx.range("xattn_aggregate_mask"):
# Reshape historical mask: [B, H, Q_bsa, historical_k_bsa] -> [B, H, Q_bsa, num_cpu, bsa_per_cpu]
# Use full Q_bsa (not valid_q_bsa) for aggregation
mask_per_cpu = mask_historical_full.view(B, H, Q_bsa, num_cpu_blocks, bsa_per_cpu)
# Union across: bsa_per_cpu, Q_bsa, heads -> [B, num_cpu]
cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1) # [B, num_cpu]
# Get selected indices
selected_indices = cpu_needed[0].nonzero().squeeze(-1).tolist()
if isinstance(selected_indices, int):
selected_indices = [selected_indices]
# Handle empty available_blocks case (first chunk)
if available_blocks:
selected_block_ids = [available_blocks[i] for i in selected_indices]
else: else:
# Edge case: all zeros, select all blocks selected_block_ids = []
selected_block_ids = list(available_blocks)
if layer_id == 0 and available_blocks:
self._stats_total_available_blocks += len(available_blocks)
self._stats_total_selected_blocks += len(selected_block_ids)
self._stats_num_chunks += 1
return selected_block_ids
# Sort by score (descending) and select until threshold is reached
sorted_indices = torch.argsort(score_ratio, descending=True)
cumsum = 0.0
selected_indices = set()
for idx in sorted_indices.tolist():
selected_indices.add(idx)
cumsum += score_ratio[idx].item()
if cumsum >= self.threshold:
break
# Map indices back to block IDs
selected_block_ids = [available_blocks[i] for i in sorted(selected_indices)]
# Always include first block (sink) and last block for safety # Always include first block (sink) and last block for safety
if available_blocks and available_blocks[0] not in selected_block_ids: if available_blocks and available_blocks[0] not in selected_block_ids:
@@ -599,6 +758,14 @@ class XAttentionBSAPolicy(SparsePolicy):
if available_blocks and available_blocks[-1] not in selected_block_ids: if available_blocks and available_blocks[-1] not in selected_block_ids:
selected_block_ids.append(available_blocks[-1]) selected_block_ids.append(available_blocks[-1])
# Record communication density (CPU block granularity) - only if there are historical blocks
if available_blocks:
DensityObserver.record_comm_density(
layer_id,
selected_cpu_blocks=len(selected_block_ids),
total_cpu_blocks=len(available_blocks),
)
# Update statistics (only for layer 0 to avoid overcounting) # Update statistics (only for layer 0 to avoid overcounting)
if layer_id == 0 and available_blocks: if layer_id == 0 and available_blocks:
self._stats_total_available_blocks += len(available_blocks) self._stats_total_available_blocks += len(available_blocks)
@@ -611,7 +778,7 @@ class XAttentionBSAPolicy(SparsePolicy):
f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}") f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}")
# Free intermediate tensors to prevent memory leak # Free intermediate tensors to prevent memory leak
del attn_scores, block_sums_fine, block_sums_coarse, cpu_block_scores, scores_per_block del attn_scores, block_sums, mask, mask_historical_full
return selected_block_ids return selected_block_ids
@@ -637,6 +804,10 @@ class XAttentionBSAPolicy(SparsePolicy):
2. Compute attention to current chunk 2. Compute attention to current chunk
3. Merge all results 3. Merge all results
Note: The BSA-level mask is saved in self._prefill_mask_buffer by select_blocks().
Currently we use flash_attn_with_lse for computation (supports LSE merge).
TODO: Optimize to use BSA kernel with the saved mask for per-head sparse attention.
Args: Args:
q: Query tensor [seq_len, num_heads, head_dim] q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer) k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
@@ -667,6 +838,11 @@ class XAttentionBSAPolicy(SparsePolicy):
# Use the pre-selected blocks directly # Use the pre-selected blocks directly
cpu_block_table = selected_blocks cpu_block_table = selected_blocks
# Note: BSA mask is available in self._prefill_mask_buffer (saved by select_blocks)
# Mask shape: [1, num_heads, Q_bsa, K_bsa] where Q_bsa = self._current_mask_q_bsa
# Selected indices: self._selected_cpu_indices, bsa_per_cpu: self._bsa_per_cpu
# TODO: Use this mask with BSA kernel for per-head sparse attention optimization
if cpu_block_table: if cpu_block_table:
with nvtx.range("xattn_compute_historical"): with nvtx.range("xattn_compute_historical"):
load_slots = list(range(offload_engine.num_ring_slots)) load_slots = list(range(offload_engine.num_ring_slots))

View File

@@ -221,8 +221,7 @@ class Attention(nn.Module):
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill) # Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill)
selected_blocks = [] # Always call select_blocks even for first chunk (cpu_block_table may be empty)
if cpu_block_table:
num_chunks = current_chunk_idx + 1 num_chunks = current_chunk_idx + 1
policy_ctx = PolicyContext( policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx, query_chunk_idx=current_chunk_idx,
@@ -231,9 +230,9 @@ class Attention(nn.Module):
query=q, # Pass query for sparse policies that need it query=q, # Pass query for sparse policies that need it
is_prefill=True, is_prefill=True,
block_size=kvcache_manager.block_size, block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size, total_kv_len=len(cpu_block_table) * kvcache_manager.block_size if cpu_block_table else 0,
) )
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx) selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks") logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
# [DEBUG] Verify execution path # [DEBUG] Verify execution path
@@ -320,7 +319,7 @@ class Attention(nn.Module):
block_size=kvcache_manager.block_size, block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size, total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
) )
selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx) selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks") logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
# [DEBUG] Verify execution path # [DEBUG] Verify execution path

View File

@@ -1,13 +1,22 @@
""" """
DensityObserver - Sparse Attention Density 统计 Observer。 DensityObserver - Sparse Attention Density 统计 Observer。
统计每层的 sparse attention density: 统计两种 density:
- density = selected_blocks / total_causal_blocks 1. Compute Density (计算密度): 基于 BSA block size (128)
- 在 causal attention 下,只计算下三角区域 - density = selected_bsa_blocks / total_causal_bsa_blocks
- GPU-only 和 Offload 模式应该一致
2. Communication Density (通信密度): 基于 CPU block size (如 4096)
- comm_density = selected_cpu_blocks / total_cpu_blocks
- 仅用于 Offload 模式,由于粒度更粗,必然 >= compute density
统计位置: 统计位置:
- GPU-only: xattn_bsa.py compute_prefill() - GPU-only: xattn_bsa.py compute_prefill() - 只记录 compute density
- Offload: xattn_bsa.py select_blocks() - Offload: xattn_bsa.py select_blocks() - 记录两种 density
对于 Offload 模式的 Density 计算:
- 不是简单的 avg 或 min
- 而是 sum(selected) / sum(total),正确处理不同 chunk 大小的权重
""" """
from typing import List, Dict, Optional, Tuple from typing import List, Dict, Optional, Tuple
@@ -26,16 +35,26 @@ class DensityObserver(Observer):
DensityObserver.complete_reset() DensityObserver.complete_reset()
# ... run inference ... # ... run inference ...
DensityObserver.record(layer_id, mask, causal=True) DensityObserver.record(layer_id, mask, causal=True)
# 或者使用累积模式 (offload):
DensityObserver.record_counts(layer_id, selected, total)
# ... # ...
DensityObserver.print_summary() DensityObserver.print_summary()
""" """
_enabled: bool = False # 默认禁用 _enabled: bool = False # 默认禁用
# 每层的 density 记录 # 每层的 compute density 记录 (BSA block 粒度)
# key: layer_id, value: list of density values (每次 prefill chunk 一个) # key: layer_id, value: list of density values (每次 prefill chunk 一个)
_layer_densities: Dict[int, List[float]] = {} _layer_densities: Dict[int, List[float]] = {}
# 每层的 communication density 记录 (CPU block 粒度,仅 offload 模式)
_layer_comm_densities: Dict[int, List[float]] = {}
# 累积模式: 记录 selected/total counts (用于 offload 模式)
# 这样可以在所有 chunks 完成后正确计算 density = sum(selected) / sum(total)
_layer_selected_counts: Dict[int, List[int]] = {}
_layer_total_counts: Dict[int, List[int]] = {}
# Mask shape 记录 (用于调试) # Mask shape 记录 (用于调试)
_last_q_blocks: int = 0 _last_q_blocks: int = 0
_last_k_blocks: int = 0 _last_k_blocks: int = 0
@@ -56,7 +75,7 @@ class DensityObserver(Observer):
causal: bool = True, causal: bool = True,
) -> float: ) -> float:
""" """
记录一层的 density。 记录一层的 density (适用于 GPU-only 模式)
Args: Args:
layer_id: 层 ID layer_id: 层 ID
@@ -82,6 +101,72 @@ class DensityObserver(Observer):
return density return density
@classmethod
def record_counts(
    cls,
    layer_id: int,
    selected_blocks: int,
    total_blocks: int,
) -> None:
    """Accumulate selected/total block counts for one layer (offload mode).

    Raw counts are stored instead of per-chunk density values so that, once
    all chunks are processed, the overall density can be computed correctly
    as sum(selected) / sum(total). This weights each chunk by its actual
    number of blocks, unlike a plain average of per-chunk densities (chunks
    have different Q and K lengths).

    Args:
        layer_id: Layer ID.
        selected_blocks: Number of blocks selected for this chunk.
        total_blocks: Total number of candidate blocks for this chunk.
    """
    if not cls._enabled:
        return
    # setdefault folds first-touch list creation and append into one step.
    cls._layer_selected_counts.setdefault(layer_id, []).append(selected_blocks)
    cls._layer_total_counts.setdefault(layer_id, []).append(total_blocks)
@classmethod
def record_comm_density(
    cls,
    layer_id: int,
    selected_cpu_blocks: int,
    total_cpu_blocks: int,
) -> float:
    """Record one layer's communication density (CPU-block granularity).

    Communication density measures how many coarse CPU blocks must be
    transferred: selected_cpu_blocks / total_cpu_blocks. Used in offload
    mode only; since CPU blocks are coarser than BSA blocks, this value is
    necessarily >= the compute density.

    Args:
        layer_id: Layer ID.
        selected_cpu_blocks: Number of CPU blocks selected.
        total_cpu_blocks: Total number of CPU blocks.

    Returns:
        The communication density. 0.0 when the observer is disabled;
        1.0 for the degenerate total == 0 case (not recorded, so it does
        not skew the accumulated statistics).
    """
    if not cls._enabled:
        return 0.0
    if total_cpu_blocks == 0:
        return 1.0
    comm_density = selected_cpu_blocks / total_cpu_blocks
    cls._layer_comm_densities.setdefault(layer_id, []).append(comm_density)
    return comm_density
@classmethod @classmethod
def _compute_density(cls, mask: torch.Tensor, causal: bool) -> float: def _compute_density(cls, mask: torch.Tensor, causal: bool) -> float:
"""计算 mask 的 density""" """计算 mask 的 density"""
@@ -107,22 +192,63 @@ class DensityObserver(Observer):
def complete_reset(cls) -> None:
    """Reset every accumulated statistic to its initial state."""
    cls._layer_densities = {}
    cls._layer_comm_densities = {}
    cls._layer_selected_counts = {}
    cls._layer_total_counts = {}
    cls._last_q_blocks = 0
    cls._last_k_blocks = 0
    cls._mode = "unknown"
@classmethod
def get_per_layer_density(cls) -> Dict[int, float]:
    """Return the density of each layer.

    Count-accumulation mode (offload): density = sum(selected) / sum(total),
    which weights chunks by their block counts.
    Direct-recording mode (gpu_only): density = mean of recorded values.
    Layers with no usable data are omitted from the result.
    """
    result: Dict[int, float] = {}
    # Accumulated counts take precedence (offload mode).
    if cls._layer_selected_counts:
        for layer_id in cls._layer_selected_counts:
            selected_list = cls._layer_selected_counts.get(layer_id, [])
            total_list = cls._layer_total_counts.get(layer_id, [])
            total_selected = sum(selected_list)
            total_total = sum(total_list)
            if total_total > 0:
                result[layer_id] = total_selected / total_total
    else:
        # Directly recorded per-chunk densities (gpu_only mode).
        for layer_id, densities in cls._layer_densities.items():
            if densities:
                result[layer_id] = sum(densities) / len(densities)
    return result
@classmethod
def get_overall_density(cls) -> float:
    """Return the overall compute density across all layers.

    Count-accumulation mode (offload): sum(all_selected) / sum(all_total).
    Direct-recording mode (gpu_only): mean of all recorded values.

    NOTE: the overall density is deliberately NOT avg(per_layer_density);
    summing the raw counts first weights every chunk and layer correctly.
    """
    # Accumulated counts take precedence (offload mode).
    if cls._layer_selected_counts:
        total_selected = 0
        total_total = 0
        for layer_id in cls._layer_selected_counts:
            total_selected += sum(cls._layer_selected_counts[layer_id])
            total_total += sum(cls._layer_total_counts.get(layer_id, []))
        if total_total > 0:
            return total_selected / total_total
        return 0.0
    # Directly recorded per-chunk densities (gpu_only mode).
    all_densities = []
    for densities in cls._layer_densities.values():
        all_densities.extend(densities)
    if not all_densities:
        return 0.0
    return sum(all_densities) / len(all_densities)
@classmethod
def get_overall_comm_density(cls) -> float:
    """Return the mean communication density across all layers.

    Returns 0.0 when nothing has been recorded (e.g. gpu_only mode, where
    record_comm_density is never called).
    """
    all_densities: List[float] = []
    for densities in cls._layer_comm_densities.values():
        all_densities.extend(densities)
    if not all_densities:
        return 0.0
    return sum(all_densities) / len(all_densities)
@classmethod @classmethod
def get_summary(cls) -> dict: def get_summary(cls) -> dict:
"""返回统计摘要""" """返回统计摘要"""
@@ -160,8 +296,13 @@ class DensityObserver(Observer):
per_layer = cls.get_per_layer_density() per_layer = cls.get_per_layer_density()
overall = cls.get_overall_density() overall = cls.get_overall_density()
min_layer, min_density = cls.get_min_density() min_layer, min_density = cls.get_min_density()
overall_comm = cls.get_overall_comm_density()
print(f"[DensityObserver] Mode: {cls._mode}") print(f"[DensityObserver] Mode: {cls._mode}")
print(f" Overall density: {overall:.4f}") print(f" Compute density: {overall:.4f} (min: {min_density:.4f} @ layer {min_layer})")
print(f" Min density: {min_density:.4f} (layer {min_layer})") if overall_comm > 0:
print(f" Comm density: {overall_comm:.4f}")
print(f" Num layers: {len(per_layer)}") print(f" Num layers: {len(per_layer)}")
# 输出 layer 0 的 density 用于对比
if 0 in per_layer:
print(f" Layer 0 density: {per_layer[0]:.6f}")

View File

@@ -41,9 +41,9 @@ K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()
for i in range(q_len): for i in range(q_len):
if i % 2 == 0: if i % 2 == 0:
Q[0, 0, i, :] = 1 Q[0, 0, i, :] = 1 * (i // stride + 1)
else: else:
Q[0, 0, i, :] = 2 Q[0, 0, i, :] = 2 * (i // stride + 1)
for i in range(kv_len): for i in range(kv_len):
if i % 2 == 0: if i % 2 == 0:
@@ -74,8 +74,11 @@ for k_chunk_idx in range(num_k_chunks):
Q, K_chunk, stride, Q, K_chunk, stride,
chunk_start=0, chunk_start=0,
chunk_end=q_reshaped_len, chunk_end=q_reshaped_len,
is_causal=False is_causal=True
) )
# NOTE(review): removed leftover `__import__('pdb').set_trace()` debugger
# breakpoint — it would drop into an interactive debugger on every K-chunk
# iteration and hang any non-interactive test run.
attn_scores_list.append(attn_chunk) attn_scores_list.append(attn_chunk)
# 拼接所有 K chunks 的结果 # 拼接所有 K chunks 的结果