WIP: Enhance sparse attention with density tracking and block selection improvements

- Added analysis documentation for xattn density alignment.
- Refactored ModelRunner to pre-allocate policy metadata buffers regardless of CPU offload configuration.
- Updated FullAttentionPolicy and SparsePolicy to accept query and key tensors for block selection.
- Enhanced QuestPolicy to utilize query tensor for block selection and improved handling of selected blocks.
- Expanded XAttentionBSAPolicy to support chunked prefill and improved attention score computation with historical and current chunk handling.
- Introduced DensityObserver to track compute and communication density for sparse attention layers.
- Updated attention layer to ensure block selection is always called, improving robustness in first chunk scenarios.
- Added tests for attention kernel behavior with enhanced input patterns.
This commit is contained in:
Zijie Tian
2026-01-31 14:48:23 +08:00
parent f6ac4ccdde
commit 2e96d1d97d
9 changed files with 490 additions and 152 deletions

View File

@@ -47,6 +47,8 @@ class FullAttentionPolicy(SparsePolicy):
available_blocks: List[int],
offload_engine: "OffloadEngine",
ctx: PolicyContext,
q: torch.Tensor,
k: torch.Tensor,
) -> List[int]:
"""Return all blocks - no sparsity."""
# Update statistics (only for layer 0 to avoid overcounting)

View File

@@ -142,6 +142,8 @@ class SparsePolicy(ABC):
available_blocks: List[int],
offload_engine: "OffloadEngine",
ctx: PolicyContext,
q: torch.Tensor,
k: torch.Tensor,
) -> List[int]:
"""
Select which KV blocks to load for the current query chunk.
@@ -158,6 +160,8 @@ class SparsePolicy(ABC):
to load KV to make selection decisions).
ctx: PolicyContext with information about the current query
chunk, layer, phase (prefill/decode), etc.
q: Query tensor [seq_len, num_heads, head_dim] for current chunk
k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk
Returns:
List of block IDs to load (must be a subset of available_blocks).

View File

@@ -191,13 +191,26 @@ class QuestPolicy(SparsePolicy):
def select_blocks(
self,
available_blocks: List[int],
offload_engine: "OffloadEngine",
ctx: PolicyContext,
q: torch.Tensor,
k: torch.Tensor,
) -> List[int]:
"""
Select Top-K blocks based on query-key similarity bounds.
If query is not available (some prefill scenarios), falls back
to loading all blocks.
Args:
available_blocks: List of CPU block IDs
offload_engine: OffloadEngine for loading KV (unused in Quest)
ctx: PolicyContext with metadata
q: Query tensor [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim] (unused in Quest, uses metadata instead)
Returns:
Selected block IDs
"""
if self.metadata is None:
raise RuntimeError(
@@ -211,7 +224,7 @@ class QuestPolicy(SparsePolicy):
if n <= self.config.threshold_blocks:
return available_blocks
if ctx.query is None:
if q is None:
# No query available - cannot compute scores
return available_blocks
@@ -221,11 +234,10 @@ class QuestPolicy(SparsePolicy):
)
# Metadata is already on GPU, same device as query
device = ctx.query.device
device = q.device
# Compute upper bound scores
# query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
q = ctx.query
# query shape: [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
if q.dim() == 4:
# Prefill: use mean over sequence length
q = q.mean(dim=1) # [1, num_heads, head_dim]

View File

@@ -135,6 +135,21 @@ class XAttentionBSAPolicy(SparsePolicy):
self._v_expanded: torch.Tensor | None = None
self._max_seq_len: int = 0
# Pre-allocated mask buffer for chunked prefill (offload mode)
# Stores BSA-level mask from select_blocks for use in compute_chunked_prefill
# Shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks]
self._prefill_mask_buffer: torch.Tensor | None = None
self._current_mask_q_bsa: int = 0 # Current Q BSA blocks in buffer
self._current_mask_k_bsa: int = 0 # Current K BSA blocks in buffer
# Selected block indices for mask extraction in compute_chunked_prefill
# Stores the indices of selected CPU blocks in available_blocks
self._selected_cpu_indices: List[int] = []
self._bsa_per_cpu: int = 0 # BSA blocks per CPU block
#> Debug: store all K cache
self._debug_k_full: torch.Tensor | None = None
def alloc_policy_metadata(
self,
num_heads: int,
@@ -162,7 +177,17 @@ class XAttentionBSAPolicy(SparsePolicy):
dtype: Data type
device: Target device
"""
# Only allocate if GQA (num_heads != num_kv_heads)
# Pre-allocate mask buffer for chunked prefill (offload mode)
# mask shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks]
# This is needed regardless of GQA
max_q_bsa_blocks = self.chunk_size // self.BSA_BLOCK_SIZE
max_k_bsa_blocks = max_seq_len // self.BSA_BLOCK_SIZE
mask_shape = (1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks)
self._prefill_mask_buffer = torch.empty(mask_shape, dtype=torch.bool, device=device)
mask_memory_mb = num_heads * max_q_bsa_blocks * max_k_bsa_blocks / (1024 * 1024)
logger.info(f"[XAttn] Pre-allocated mask buffer: shape={mask_shape}, memory={mask_memory_mb:.1f} MB")
# Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads)
if num_heads == num_kv_heads:
logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
return
@@ -176,6 +201,9 @@ class XAttentionBSAPolicy(SparsePolicy):
memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")
#DEBUG : buffer for save all K cache.
self._debug_k_full = torch.empty((1, num_heads, max_seq_len, head_dim), dtype=dtype, device=device)
# =========================================================================
# GPU-only methods (non-chunked)
@@ -401,33 +429,42 @@ class XAttentionBSAPolicy(SparsePolicy):
available_blocks: List[int],
offload_engine: "OffloadEngine",
ctx: PolicyContext,
q: torch.Tensor,
k: torch.Tensor,
) -> List[int]:
"""
Compute attention scores for all available blocks using flat_group_gemm,
then use softmax_fuse_block_sum and find_blocks_chunked to select important blocks.
This method:
1. Loads each K block from CPU
2. Computes Q@K^T attention scores using XAttention stride reshape
3. Applies softmax_fuse_block_sum to get block-level attention
4. Uses find_blocks_chunked to select blocks based on threshold
This method aligns with GPU-only xattn_estimate_chunked:
1. Loads each K block from CPU (historical blocks)
2. Gets current chunk K from prefill buffer
3. Concatenates [historical K, current chunk K] for correct softmax normalization
4. Uses causal=True with correct chunk_start for position-aware masking
5. Only selects from historical blocks (current chunk is always full attention)
Args:
available_blocks: List of CPU block IDs
available_blocks: List of CPU block IDs (historical blocks only)
offload_engine: OffloadEngine for loading blocks
ctx: PolicyContext with query tensor and metadata
ctx: PolicyContext with metadata
q: Query tensor [seq_len, num_heads, head_dim] for current chunk
k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk (used for estimation)
Returns:
Selected block IDs based on attention threshold
"""
if not available_blocks or ctx.query is None:
if q is None:
return available_blocks
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum, find_blocks_chunked
import math
layer_id = ctx.layer_id
q = ctx.query # [seq_len, num_heads, head_dim]
# Use passed q parameter instead of ctx.query
# Set DensityObserver mode on first layer
if layer_id == 0:
DensityObserver.set_mode("offload")
# Convert Q to [batch, heads, seq_len, head_dim]
# q: [seq_len, num_heads, head_dim] -> [1, num_heads, seq_len, head_dim]
@@ -454,18 +491,37 @@ class XAttentionBSAPolicy(SparsePolicy):
q_reshaped_len = padded_q_len // self.stride
# Use a single slot for loading (synchronous mode for simplicity)
# Get block size from context
block_size = ctx.block_size # tokens per CPU block (e.g., 4096)
reshaped_block_size = block_size // self.stride # e.g., 4096/8 = 512
# ============================================================
# Step 1: Compute chunk_start and related parameters
# ============================================================
# chunk_start = Q's global position in reshaped space
# Q starts at position: num_historical_blocks * block_size
num_historical_blocks = len(available_blocks)
historical_k_len = num_historical_blocks * block_size
chunk_start = historical_k_len // self.stride # Q's position in reshaped space
chunk_end = chunk_start + q_reshaped_len
# For valid Q length tracking (excluding padding)
valid_q_reshaped = (q_len + self.stride - 1) // self.stride
real_q_len = chunk_start + valid_q_reshaped
# ============================================================
# Step 2: Pipeline load historical K blocks and compute attn_scores
# ============================================================
# Key design: Load each block, compute immediately, then release
# This avoids storing all K in GPU memory at once (offload-friendly)
slot = 0
attn_scores_list = []
BLOCK_N = 128
k_alignment = self.stride * BLOCK_N
# Get block size from context
block_size = ctx.block_size # tokens per CPU block (e.g., 1024)
reshaped_block_size = block_size // self.stride # e.g., 1024/8 = 128
with nvtx.range("xattn_estimate_gemm"):
with nvtx.range("xattn_estimate_historical"):
for cpu_block_id in available_blocks:
# Load only K from CPU to GPU (V not needed for estimate)
# This saves 50% communication in the estimate phase
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
offload_engine.wait_slot_layer(slot)
@@ -473,125 +529,228 @@ class XAttentionBSAPolicy(SparsePolicy):
k_block = offload_engine.get_k_for_slot(slot)
# Convert K to [batch, heads, k_len, head_dim]
# k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim]
K_chunk = k_block.transpose(1, 2)
K_chunk = k_block.transpose(1, 2) # [1, num_kv_heads, block_size, head_dim]
# Handle GQA: expand K heads to match Q heads
num_kv_heads = K_chunk.shape[1]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
#> DEBUG: save all K cache
start_pos = cpu_block_id * block_size
self._debug_k_full[:, :, start_pos:start_pos + block_size, :].copy_(K_chunk)
# # Pad K if necessary
# k_len = K_chunk.shape[2]
# if k_len < k_alignment:
# pad_size = k_alignment - k_len
# K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
# Pad K if necessary (k_len must be divisible by stride * BLOCK_N)
k_len = K_chunk.shape[2]
BLOCK_N = 128
k_alignment = self.stride * BLOCK_N
if k_len < k_alignment:
# K too short, pad it
pad_size = k_alignment - k_len
K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
# Compute attention scores using flat_group_gemm_fuse_reshape
# Output: [batch, heads, q_len/stride, k_len/stride]
attn_chunk = flat_group_gemm_fuse_reshape(
Q, K_chunk, self.stride,
chunk_start=0,
chunk_end=q_reshaped_len,
is_causal=False
)
attn_scores_list.append(attn_chunk)
# # Compute attention scores for this historical block
# # Historical blocks: all positions < Q, so Q always sees them (full attention)
# # Use LOCAL chunk_start=0 to match test_xattn_k_chunked.py behavior
# attn_chunk = flat_group_gemm_fuse_reshape(
# Q, K_chunk, self.stride,
# chunk_start=0, # Local: same as test
# chunk_end=q_reshaped_len,
# is_causal=False, # Historical K: all visible to Q
# )
# attn_scores_list.append(attn_chunk)
# Mark slot as done for reuse
offload_engine.record_slot_compute_done(slot)
num_kv_heads = k.shape[1]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
k_repeated = k.repeat_interleave(num_groups, dim=1).unsqueeze(0).transpose(1, 2) # [1, num_heads, historical_k_len, head_dim]
self._debug_k_full[:, :, historical_k_len:historical_k_len + q_len, :].copy_(k_repeated)
if layer_id == 0:
__import__('pdb').set_trace()
# Concatenate all attention scores along K dimension
# Each chunk: [1, heads, q_reshaped_len, block_reshaped_len]
# Result: [1, heads, q_reshaped_len, total_k_reshaped_len]
# ============================================================
# Step 3: Get current chunk K and compute its attn_scores
# ============================================================
with nvtx.range("xattn_estimate_current"):
# Current chunk K is in prefill buffer (already on GPU)
k_curr, _ = offload_engine.get_prefill_buffer_slice(layer_id, q_len)
# k_curr: [1, q_len, num_kv_heads, head_dim] -> [1, num_kv_heads, q_len, head_dim]
K_current = k_curr.transpose(1, 2)
# Handle GQA for current chunk K
num_kv_heads = K_current.shape[1]
if num_heads != num_kv_heads:
num_groups = num_heads // num_kv_heads
K_current = K_current.repeat_interleave(num_groups, dim=1)
# Pad current K if necessary
curr_k_len = K_current.shape[2]
padded_curr_k_len = ((curr_k_len + k_alignment - 1) // k_alignment) * k_alignment
if padded_curr_k_len != curr_k_len:
pad_size = padded_curr_k_len - curr_k_len
K_current = torch.nn.functional.pad(K_current, (0, 0, 0, pad_size), value=0)
# Compute attention scores for current chunk
# IMPORTANT: Use LOCAL coordinates (0 to q_reshaped_len) for current chunk!
# Because K_current only contains current chunk K (not full sequence),
# block_n in kernel starts from 0. Using global chunk_start would cause
# incorrect causal mask (Q would see K blocks it shouldn't).
attn_current = flat_group_gemm_fuse_reshape(
Q, K_current, self.stride,
chunk_start=0, # Local: Q starts at 0 relative to K_current
chunk_end=q_reshaped_len, # Local: Q ends at q_reshaped_len
is_causal=True, # Current chunk: apply causal mask
)
attn_scores_list.append(attn_current)
del K_current
# ============================================================
# Step 4: Concatenate all attn_scores
# ============================================================
if not attn_scores_list:
return available_blocks
attn_scores = torch.cat(attn_scores_list, dim=-1)
# Free intermediate list immediately
del attn_scores_list
# Step 2: Apply softmax_fuse_block_sum with hierarchical aggregation
# Use smaller estimate_block_size (1024) for 15x faster softmax kernel,
# then aggregate to CPU block level (4096).
#
# Hierarchical approach:
# 1. softmax_fuse_block_sum with estimate_block_size (1024) -> fine-grained scores
# 2. Aggregate: reshape + sum -> CPU block level scores
# 3. Select blocks based on score + threshold (NOT mask + voting)
cpu_block_size = block_size # e.g., 4096
estimate_bs = self.estimate_block_size # e.g., 1024 (15x faster)
ratio = cpu_block_size // estimate_bs # e.g., 4
# Calculate padded K length for later use
padded_k_len = historical_k_len + padded_curr_k_len
# Use estimate_block_size for softmax kernel (optimized)
reshaped_est_bs = estimate_bs // self.stride # e.g., 1024/8 = 128
norm = 1.0 # Normalization factor
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm # log2(e) with scaling
segment_size = min(4096, reshaped_est_bs)
# ============================================================
# Step 5: Apply softmax_fuse_block_sum with causal=True
# ============================================================
cpu_block_size = block_size # e.g., 4096
bsa_per_cpu = cpu_block_size // self.BSA_BLOCK_SIZE # e.g., 4096/128 = 32
# Use BSA_BLOCK_SIZE for block aggregation (aligned with GPU-only)
reshaped_bsa_bs = self.BSA_BLOCK_SIZE // self.stride # e.g., 128/8 = 16
norm = 1.0
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm
segment_size = min(4096, reshaped_bsa_bs)
with nvtx.range("xattn_estimate_softmax"):
block_sums_fine = softmax_fuse_block_sum(
block_sums = softmax_fuse_block_sum(
attn_scores,
reshaped_est_bs, # Use optimized estimate block size (128 vs 512)
reshaped_bsa_bs,
segment_size,
chunk_start=0,
chunk_end=q_reshaped_len,
real_q_len=q_reshaped_len,
chunk_start=chunk_start,
chunk_end=chunk_end,
real_q_len=real_q_len,
scale=scale,
is_causal=False, # Historical blocks are all before current chunk
is_causal=True, # Causal for consistent with GPU-only
)
# block_sums_fine shape: [batch, heads, q_est_blocks, k_est_blocks]
# where k_est_blocks = len(available_blocks) * ratio
# block_sums shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]
# Step 3: Aggregate to CPU block level (hierarchical sum)
# This is mathematically equivalent to direct computation but much faster
batch_size_bs, num_heads_bs, q_est_blocks, k_est_blocks = block_sums_fine.shape
num_cpu_blocks = len(available_blocks)
# ============================================================
# Step 6: Use find_blocks_chunked to generate BSA-level mask
# ============================================================
# Calculate BSA block indices
q_bsa_blocks = (padded_q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
total_k_bsa_blocks = (padded_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
historical_k_bsa_blocks = num_historical_blocks * bsa_per_cpu
with nvtx.range("xattn_estimate_aggregate"):
# Reshape: [batch, heads, q_est, k_est] -> [batch, heads, q_est, num_cpu, ratio]
block_sums_coarse = block_sums_fine.view(
batch_size_bs, num_heads_bs, q_est_blocks, num_cpu_blocks, ratio
).sum(dim=-1) # [batch, heads, q_est_blocks, num_cpu_blocks]
# current_index for find_blocks_chunked: Q's block offset
q_start_bsa_block = historical_k_bsa_blocks # Q starts after historical K
# Sum over Q dimension to get total attention from Q chunk to each K block
cpu_block_scores = block_sums_coarse.sum(dim=2) # [batch, heads, num_cpu_blocks]
with nvtx.range("xattn_find_blocks"):
mask = find_blocks_chunked(
block_sums,
current_index=q_start_bsa_block, # Q's position in BSA blocks
threshold=self.threshold,
num_to_choose=None,
decoding=False,
mode="prefill",
causal=True, # Causal for block-level mask
)
# mask shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]
# Step 4: Select blocks using score + threshold (replaces mask + majority voting)
# This is simpler and more direct than the original mask-based approach
with nvtx.range("xattn_estimate_select"):
# Average scores across heads (GQA-aware: all heads contribute equally)
scores_per_block = cpu_block_scores.mean(dim=(0, 1)) # [num_cpu_blocks]
# ============================================================
# Step 7: Extract mask portions and record density
# ============================================================
B, H, Q_bsa, K_bsa_total = mask.shape
# Normalize to get attention distribution
total_score = scores_per_block.sum()
if total_score > 0:
score_ratio = scores_per_block / total_score
else:
# Edge case: all zeros, select all blocks
selected_block_ids = list(available_blocks)
if layer_id == 0 and available_blocks:
self._stats_total_available_blocks += len(available_blocks)
self._stats_total_selected_blocks += len(selected_block_ids)
self._stats_num_chunks += 1
return selected_block_ids
# Calculate valid Q blocks (excluding padding)
valid_q_bsa = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
valid_curr_k_bsa = (curr_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
# Sort by score (descending) and select until threshold is reached
sorted_indices = torch.argsort(score_ratio, descending=True)
cumsum = 0.0
selected_indices = set()
# 7a: Record historical blocks density
# IMPORTANT: For historical blocks, apply causal mask to match GPU-only density calculation!
# Q block i (global position = q_start_bsa_block + i) can see historical K block j
# only if j <= q_start_bsa_block + i (causal constraint)
mask_historical = mask[:, :, :valid_q_bsa, :historical_k_bsa_blocks]
for idx in sorted_indices.tolist():
selected_indices.add(idx)
cumsum += score_ratio[idx].item()
if cumsum >= self.threshold:
break
if historical_k_bsa_blocks > 0:
# Create causal mask for historical blocks
# Q_global[i] = q_start_bsa_block + i, K[j] = j
# Causal: j <= Q_global[i] => j <= q_start_bsa_block + i
q_global_indices = torch.arange(valid_q_bsa, device=mask.device) + q_start_bsa_block
k_indices = torch.arange(historical_k_bsa_blocks, device=mask.device)
# Q at position q_global_indices[i] can see K at position k_indices[j] if k_indices[j] <= q_global_indices[i]
causal_mask_historical = k_indices.unsqueeze(0) <= q_global_indices.unsqueeze(1) # [valid_q_bsa, historical_k_bsa_blocks]
# Map indices back to block IDs
selected_block_ids = [available_blocks[i] for i in sorted(selected_indices)]
# Count positions within causal mask only
total_historical_causal = causal_mask_historical.sum().item() * B * H
selected_historical = (mask_historical & causal_mask_historical.unsqueeze(0).unsqueeze(0)).sum().item()
if total_historical_causal > 0:
DensityObserver.record_counts(layer_id, selected_historical, total_historical_causal)
# 7b: Record current chunk density (causal, to align with GPU-only mode)
# Current chunk is the portion after historical blocks
if valid_curr_k_bsa > 0:
# Extract current chunk mask (only valid portion, not padded)
mask_current = mask[:, :, :valid_q_bsa, historical_k_bsa_blocks:historical_k_bsa_blocks + valid_curr_k_bsa]
q_dim = mask_current.shape[2]
k_dim = mask_current.shape[3]
# Create causal mask (lower triangular)
# For current chunk: Q[i] can see K[j] where j <= i (standard causal)
causal_mask = torch.tril(torch.ones(q_dim, k_dim, device=mask.device, dtype=torch.bool))
# Count positions within causal mask only
total_current_causal = causal_mask.sum().item() * B * H
selected_current = (mask_current & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
if total_current_causal > 0:
DensityObserver.record_counts(layer_id, selected_current, total_current_causal)
# Step 7.5: Save historical mask to pre-allocated buffer for compute_chunked_prefill
# Use full Q_bsa (padded) for buffer, not valid_q_bsa
mask_historical_full = mask[:, :, :, :historical_k_bsa_blocks]
if self._prefill_mask_buffer is not None:
# Only save historical portion of mask
self._prefill_mask_buffer[:, :, :Q_bsa, :historical_k_bsa_blocks].copy_(mask_historical_full)
self._current_mask_q_bsa = Q_bsa
self._current_mask_k_bsa = historical_k_bsa_blocks
# ============================================================
# Step 8: Aggregate mask to CPU block level (union of heads)
# ============================================================
# Only aggregate historical blocks (current chunk is always full attention)
num_cpu_blocks = num_historical_blocks
with nvtx.range("xattn_aggregate_mask"):
# Reshape historical mask: [B, H, Q_bsa, historical_k_bsa] -> [B, H, Q_bsa, num_cpu, bsa_per_cpu]
# Use full Q_bsa (not valid_q_bsa) for aggregation
mask_per_cpu = mask_historical_full.view(B, H, Q_bsa, num_cpu_blocks, bsa_per_cpu)
# Union across: bsa_per_cpu, Q_bsa, heads -> [B, num_cpu]
cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1) # [B, num_cpu]
# Get selected indices
selected_indices = cpu_needed[0].nonzero().squeeze(-1).tolist()
if isinstance(selected_indices, int):
selected_indices = [selected_indices]
# Handle empty available_blocks case (first chunk)
if available_blocks:
selected_block_ids = [available_blocks[i] for i in selected_indices]
else:
selected_block_ids = []
# Always include first block (sink) and last block for safety
if available_blocks and available_blocks[0] not in selected_block_ids:
@@ -599,6 +758,14 @@ class XAttentionBSAPolicy(SparsePolicy):
if available_blocks and available_blocks[-1] not in selected_block_ids:
selected_block_ids.append(available_blocks[-1])
# Record communication density (CPU block granularity) - only if there are historical blocks
if available_blocks:
DensityObserver.record_comm_density(
layer_id,
selected_cpu_blocks=len(selected_block_ids),
total_cpu_blocks=len(available_blocks),
)
# Update statistics (only for layer 0 to avoid overcounting)
if layer_id == 0 and available_blocks:
self._stats_total_available_blocks += len(available_blocks)
@@ -611,7 +778,7 @@ class XAttentionBSAPolicy(SparsePolicy):
f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}")
# Free intermediate tensors to prevent memory leak
del attn_scores, block_sums_fine, block_sums_coarse, cpu_block_scores, scores_per_block
del attn_scores, block_sums, mask, mask_historical_full
return selected_block_ids
@@ -637,6 +804,10 @@ class XAttentionBSAPolicy(SparsePolicy):
2. Compute attention to current chunk
3. Merge all results
Note: The BSA-level mask is saved in self._prefill_mask_buffer by select_blocks().
Currently we use flash_attn_with_lse for computation (supports LSE merge).
TODO: Optimize to use BSA kernel with the saved mask for per-head sparse attention.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
@@ -667,6 +838,11 @@ class XAttentionBSAPolicy(SparsePolicy):
# Use the pre-selected blocks directly
cpu_block_table = selected_blocks
# Note: BSA mask is available in self._prefill_mask_buffer (saved by select_blocks)
# Mask shape: [1, num_heads, Q_bsa, K_bsa] where Q_bsa = self._current_mask_q_bsa
# Selected indices: self._selected_cpu_indices, bsa_per_cpu: self._bsa_per_cpu
# TODO: Use this mask with BSA kernel for per-head sparse attention optimization
if cpu_block_table:
with nvtx.range("xattn_compute_historical"):
load_slots = list(range(offload_engine.num_ring_slots))