From 2e96d1d97d11096fe0e943992ee53acff39ae41d Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Sat, 31 Jan 2026 14:48:23 +0800 Subject: [PATCH] WIP: Enhance sparse attention with density tracking and block selection improvements - Added analysis documentation for xattn density alignment. - Refactored ModelRunner to pre-allocate policy metadata buffers regardless of CPU offload configuration. - Updated FullAttentionPolicy and SparsePolicy to accept query and key tensors for block selection. - Enhanced QuestPolicy to utilize query tensor for block selection and improved handling of selected blocks. - Expanded XAttentionBSAPolicy to support chunked prefill and improved attention score computation with historical and current chunk handling. - Introduced DensityObserver to track compute and communication density for sparse attention layers. - Updated attention layer to ensure block selection is always called, improving robustness in first chunk scenarios. - Added tests for attention kernel behavior with enhanced input patterns. --- CLAUDE.md | 1 + nanovllm/engine/model_runner.py | 20 +- nanovllm/kvcache/sparse/full_policy.py | 2 + nanovllm/kvcache/sparse/policy.py | 4 + nanovllm/kvcache/sparse/quest.py | 20 +- nanovllm/kvcache/sparse/xattn_bsa.py | 388 ++++++++++++++++++------- nanovllm/layers/attention.py | 29 +- nanovllm/utils/density_observer.py | 169 ++++++++++- tests/test_xattn_kernels.py | 9 +- 9 files changed, 490 insertions(+), 152 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index e643971..c9e8344 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -37,6 +37,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L | [`docs/estimate_block_size_performance.md`](docs/estimate_block_size_performance.md) | 🔥 PERF: estimate 阶段 block_size 性能分析,softmax_fuse_block_sum 最优点 (512-1024),当前 4096 慢 15x | | [`docs/long_context_models_1m.md`](docs/long_context_models_1m.md) | 📚 REF: 1M+ 上下文长度模型列表 (Qwen/GLM/InternLM/Llama/VL),≤10B 推荐模型 | | [`docs/new_model_integration_guide.md`](docs/new_model_integration_guide.md) | 🔧 GUIDE: 新模型整合指南 - 配置映射、RoPE变体、EOS处理、权重转换、验证清单 | +| [`docs/xattn_density_alignment_analysis.md`](docs/xattn_density_alignment_analysis.md) | 📊 ANALYSIS: GPU-only vs Offload 模式 density 对齐分析,chunked softmax 边界效应,5-7% 差异根因 | ## Rules Index diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index 72e2e77..8fb8708 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -229,16 +229,16 @@ class ModelRunner: # GPU-only mode: pre-allocate policy metadata buffers # This avoids dynamic GPU memory allocation during forward pass - if not config.enable_cpu_offload: - num_heads = hf_config.num_attention_heads // self.world_size - self.kvcache_manager.sparse_policy.alloc_policy_metadata( - num_heads=num_heads, - num_kv_heads=num_kv_heads, - head_dim=head_dim, - max_seq_len=config.max_model_len, - dtype=hf_config.torch_dtype, - device=torch.device("cuda"), - ) + # if not config.enable_cpu_offload: + num_heads = hf_config.num_attention_heads // self.world_size + self.kvcache_manager.sparse_policy.alloc_policy_metadata( + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + max_seq_len=config.max_model_len, + dtype=hf_config.torch_dtype, + device=torch.device("cuda"), + ) # Log policy info (handle both enum and None cases) policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL" diff --git a/nanovllm/kvcache/sparse/full_policy.py b/nanovllm/kvcache/sparse/full_policy.py index 5b6606c..19ab12a 100644 --- a/nanovllm/kvcache/sparse/full_policy.py +++ b/nanovllm/kvcache/sparse/full_policy.py @@ -47,6 +47,8 @@ class FullAttentionPolicy(SparsePolicy): available_blocks: List[int], offload_engine: "OffloadEngine", ctx: PolicyContext, + q: torch.Tensor, + k: torch.Tensor, ) -> List[int]: """Return all blocks - no sparsity.""" # Update statistics (only for layer 0 to avoid overcounting) diff --git a/nanovllm/kvcache/sparse/policy.py b/nanovllm/kvcache/sparse/policy.py index d1c3e33..1a51f87 100644 --- a/nanovllm/kvcache/sparse/policy.py +++ b/nanovllm/kvcache/sparse/policy.py @@ -142,6 +142,8 @@ class SparsePolicy(ABC): available_blocks: List[int], offload_engine: "OffloadEngine", ctx: PolicyContext, + q: torch.Tensor, + k: torch.Tensor, ) -> List[int]: """ Select which KV blocks to load for the current query chunk. @@ -158,6 +160,8 @@ class SparsePolicy(ABC): to load KV to make selection decisions). ctx: PolicyContext with information about the current query chunk, layer, phase (prefill/decode), etc. + q: Query tensor [seq_len, num_heads, head_dim] for current chunk + k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk Returns: List of block IDs to load (must be a subset of available_blocks). diff --git a/nanovllm/kvcache/sparse/quest.py b/nanovllm/kvcache/sparse/quest.py index 42b96fc..38615cd 100644 --- a/nanovllm/kvcache/sparse/quest.py +++ b/nanovllm/kvcache/sparse/quest.py @@ -191,13 +191,26 @@ class QuestPolicy(SparsePolicy): def select_blocks( self, available_blocks: List[int], + offload_engine: "OffloadEngine", ctx: PolicyContext, + q: torch.Tensor, + k: torch.Tensor, ) -> List[int]: """ Select Top-K blocks based on query-key similarity bounds. If query is not available (some prefill scenarios), falls back to loading all blocks. + + Args: + available_blocks: List of CPU block IDs + offload_engine: OffloadEngine for loading KV (unused in Quest) + ctx: PolicyContext with metadata + q: Query tensor [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim] + k: Key tensor [seq_len, num_kv_heads, head_dim] (unused in Quest, uses metadata instead) + + Returns: + Selected block IDs """ if self.metadata is None: raise RuntimeError( @@ -211,7 +224,7 @@ class QuestPolicy(SparsePolicy): if n <= self.config.threshold_blocks: return available_blocks - if ctx.query is None: + if q is None: # No query available - cannot compute scores return available_blocks @@ -221,11 +234,10 @@ class QuestPolicy(SparsePolicy): ) # Metadata is already on GPU, same device as query - device = ctx.query.device + device = q.device # Compute upper bound scores - # query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim] - q = ctx.query + # query shape: [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim] if q.dim() == 4: # Prefill: use mean over sequence length q = q.mean(dim=1) # [1, num_heads, head_dim] diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index 04098c0..92adc48 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -135,6 +135,21 @@ class XAttentionBSAPolicy(SparsePolicy): self._v_expanded: torch.Tensor | None = None self._max_seq_len: int = 0 + # Pre-allocated mask buffer for chunked prefill (offload mode) + # Stores BSA-level mask from select_blocks for use in compute_chunked_prefill + # Shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks] + self._prefill_mask_buffer: torch.Tensor | None = None + self._current_mask_q_bsa: int = 0 # Current Q BSA blocks in buffer + self._current_mask_k_bsa: int = 0 # Current K BSA blocks in buffer + + # Selected block indices for mask extraction in compute_chunked_prefill + # Stores the indices of selected CPU blocks in available_blocks + self._selected_cpu_indices: List[int] = [] + self._bsa_per_cpu: int = 0 # BSA blocks per CPU block + + #> Debug: store all K cache + self._debug_k_full: torch.Tensor | None = None + def alloc_policy_metadata( self, num_heads: int, @@ -162,7 +177,17 @@ class XAttentionBSAPolicy(SparsePolicy): dtype: Data type device: Target device """ - # Only allocate if GQA (num_heads != num_kv_heads) + # Pre-allocate mask buffer for chunked prefill (offload mode) + # mask shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks] + # This is needed regardless of GQA + max_q_bsa_blocks = self.chunk_size // self.BSA_BLOCK_SIZE + max_k_bsa_blocks = max_seq_len // self.BSA_BLOCK_SIZE + mask_shape = (1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks) + self._prefill_mask_buffer = torch.empty(mask_shape, dtype=torch.bool, device=device) + mask_memory_mb = num_heads * max_q_bsa_blocks * max_k_bsa_blocks / (1024 * 1024) + logger.info(f"[XAttn] Pre-allocated mask buffer: shape={mask_shape}, memory={mask_memory_mb:.1f} MB") + + # Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads) if num_heads == num_kv_heads: logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})") return @@ -176,6 +201,9 @@ class XAttentionBSAPolicy(SparsePolicy): memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024) logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB") + + #DEBUG : buffer for save all K cache. + self._debug_k_full = torch.empty((1, num_heads, max_seq_len, head_dim), dtype=dtype, device=device) # ========================================================================= # GPU-only methods (non-chunked) @@ -401,33 +429,42 @@ class XAttentionBSAPolicy(SparsePolicy): available_blocks: List[int], offload_engine: "OffloadEngine", ctx: PolicyContext, + q: torch.Tensor, + k: torch.Tensor, ) -> List[int]: """ Compute attention scores for all available blocks using flat_group_gemm, then use softmax_fuse_block_sum and find_blocks_chunked to select important blocks. - This method: - 1. Loads each K block from CPU - 2. Computes Q@K^T attention scores using XAttention stride reshape - 3. Applies softmax_fuse_block_sum to get block-level attention - 4. Uses find_blocks_chunked to select blocks based on threshold + This method aligns with GPU-only xattn_estimate_chunked: + 1. Loads each K block from CPU (historical blocks) + 2. Gets current chunk K from prefill buffer + 3. Concatenates [historical K, current chunk K] for correct softmax normalization + 4. Uses causal=True with correct chunk_start for position-aware masking + 5. Only selects from historical blocks (current chunk is always full attention) Args: - available_blocks: List of CPU block IDs + available_blocks: List of CPU block IDs (historical blocks only) offload_engine: OffloadEngine for loading blocks - ctx: PolicyContext with query tensor and metadata + ctx: PolicyContext with metadata + q: Query tensor [seq_len, num_heads, head_dim] for current chunk + k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk (used for estimation) Returns: Selected block IDs based on attention threshold """ - if not available_blocks or ctx.query is None: + if q is None: return available_blocks from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum, find_blocks_chunked import math layer_id = ctx.layer_id - q = ctx.query # [seq_len, num_heads, head_dim] + # Use passed q parameter instead of ctx.query + + # Set DensityObserver mode on first layer + if layer_id == 0: + DensityObserver.set_mode("offload") # Convert Q to [batch, heads, seq_len, head_dim] # q: [seq_len, num_heads, head_dim] -> [1, num_heads, seq_len, head_dim] @@ -454,18 +491,37 @@ class XAttentionBSAPolicy(SparsePolicy): q_reshaped_len = padded_q_len // self.stride - # Use a single slot for loading (synchronous mode for simplicity) + # Get block size from context + block_size = ctx.block_size # tokens per CPU block (e.g., 4096) + reshaped_block_size = block_size // self.stride # e.g., 4096/8 = 512 + + # ============================================================ + # Step 1: Compute chunk_start and related parameters + # ============================================================ + # chunk_start = Q's global position in reshaped space + # Q starts at position: num_historical_blocks * block_size + num_historical_blocks = len(available_blocks) + historical_k_len = num_historical_blocks * block_size + chunk_start = historical_k_len // self.stride # Q's position in reshaped space + chunk_end = chunk_start + q_reshaped_len + + # For valid Q length tracking (excluding padding) + valid_q_reshaped = (q_len + self.stride - 1) // self.stride + real_q_len = chunk_start + valid_q_reshaped + + # ============================================================ + # Step 2: Pipeline load historical K blocks and compute attn_scores + # ============================================================ + # Key design: Load each block, compute immediately, then release + # This avoids storing all K in GPU memory at once (offload-friendly) slot = 0 attn_scores_list = [] + BLOCK_N = 128 + k_alignment = self.stride * BLOCK_N - # Get block size from context - block_size = ctx.block_size # tokens per CPU block (e.g., 1024) - reshaped_block_size = block_size // self.stride # e.g., 1024/8 = 128 - - with nvtx.range("xattn_estimate_gemm"): + with nvtx.range("xattn_estimate_historical"): for cpu_block_id in available_blocks: # Load only K from CPU to GPU (V not needed for estimate) - # This saves 50% communication in the estimate phase offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) offload_engine.wait_slot_layer(slot) @@ -473,125 +529,228 @@ class XAttentionBSAPolicy(SparsePolicy): k_block = offload_engine.get_k_for_slot(slot) # Convert K to [batch, heads, k_len, head_dim] - # k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim] - K_chunk = k_block.transpose(1, 2) - + K_chunk = k_block.transpose(1, 2) # [1, num_kv_heads, block_size, head_dim] + # Handle GQA: expand K heads to match Q heads num_kv_heads = K_chunk.shape[1] if num_heads != num_kv_heads: num_groups = num_heads // num_kv_heads K_chunk = K_chunk.repeat_interleave(num_groups, dim=1) + + #> DEBUG: save all K cache + start_pos = cpu_block_id * block_size + self._debug_k_full[:, :, start_pos:start_pos + block_size, :].copy_(K_chunk) + + # # Pad K if necessary + # k_len = K_chunk.shape[2] + # if k_len < k_alignment: + # pad_size = k_alignment - k_len + # K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0) - # Pad K if necessary (k_len must be divisible by stride * BLOCK_N) - k_len = K_chunk.shape[2] - BLOCK_N = 128 - k_alignment = self.stride * BLOCK_N - if k_len < k_alignment: - # K too short, pad it - pad_size = k_alignment - k_len - K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0) - - # Compute attention scores using flat_group_gemm_fuse_reshape - # Output: [batch, heads, q_len/stride, k_len/stride] - attn_chunk = flat_group_gemm_fuse_reshape( - Q, K_chunk, self.stride, - chunk_start=0, - chunk_end=q_reshaped_len, - is_causal=False - ) - attn_scores_list.append(attn_chunk) + # # Compute attention scores for this historical block + # # Historical blocks: all positions < Q, so Q always sees them (full attention) + # # Use LOCAL chunk_start=0 to match test_xattn_k_chunked.py behavior + # attn_chunk = flat_group_gemm_fuse_reshape( + # Q, K_chunk, self.stride, + # chunk_start=0, # Local: same as test + # chunk_end=q_reshaped_len, + # is_causal=False, # Historical K: all visible to Q + # ) + # attn_scores_list.append(attn_chunk) # Mark slot as done for reuse offload_engine.record_slot_compute_done(slot) + + num_kv_heads = k.shape[1] + if num_heads != num_kv_heads: + num_groups = num_heads // num_kv_heads + k_repeated = k.repeat_interleave(num_groups, dim=1).unsqueeze(0).transpose(1, 2) # [1, num_heads, historical_k_len, head_dim] + + self._debug_k_full[:, :, historical_k_len:historical_k_len + q_len, :].copy_(k_repeated) + + if layer_id == 0: + __import__('pdb').set_trace() - # Concatenate all attention scores along K dimension - # Each chunk: [1, heads, q_reshaped_len, block_reshaped_len] - # Result: [1, heads, q_reshaped_len, total_k_reshaped_len] + # ============================================================ + # Step 3: Get current chunk K and compute its attn_scores + # ============================================================ + with nvtx.range("xattn_estimate_current"): + # Current chunk K is in prefill buffer (already on GPU) + k_curr, _ = offload_engine.get_prefill_buffer_slice(layer_id, q_len) + # k_curr: [1, q_len, num_kv_heads, head_dim] -> [1, num_kv_heads, q_len, head_dim] + K_current = k_curr.transpose(1, 2) + + # Handle GQA for current chunk K + num_kv_heads = K_current.shape[1] + if num_heads != num_kv_heads: + num_groups = num_heads // num_kv_heads + K_current = K_current.repeat_interleave(num_groups, dim=1) + + # Pad current K if necessary + curr_k_len = K_current.shape[2] + padded_curr_k_len = ((curr_k_len + k_alignment - 1) // k_alignment) * k_alignment + if padded_curr_k_len != curr_k_len: + pad_size = padded_curr_k_len - curr_k_len + K_current = torch.nn.functional.pad(K_current, (0, 0, 0, pad_size), value=0) + + # Compute attention scores for current chunk + # IMPORTANT: Use LOCAL coordinates (0 to q_reshaped_len) for current chunk! + # Because K_current only contains current chunk K (not full sequence), + # block_n in kernel starts from 0. Using global chunk_start would cause + # incorrect causal mask (Q would see K blocks it shouldn't). + attn_current = flat_group_gemm_fuse_reshape( + Q, K_current, self.stride, + chunk_start=0, # Local: Q starts at 0 relative to K_current + chunk_end=q_reshaped_len, # Local: Q ends at q_reshaped_len + is_causal=True, # Current chunk: apply causal mask + ) + attn_scores_list.append(attn_current) + del K_current + + # ============================================================ + # Step 4: Concatenate all attn_scores + # ============================================================ if not attn_scores_list: return available_blocks attn_scores = torch.cat(attn_scores_list, dim=-1) - # Free intermediate list immediately del attn_scores_list - # Step 2: Apply softmax_fuse_block_sum with hierarchical aggregation - # Use smaller estimate_block_size (1024) for 15x faster softmax kernel, - # then aggregate to CPU block level (4096). - # - # Hierarchical approach: - # 1. softmax_fuse_block_sum with estimate_block_size (1024) -> fine-grained scores - # 2. Aggregate: reshape + sum -> CPU block level scores - # 3. Select blocks based on score + threshold (NOT mask + voting) - cpu_block_size = block_size # e.g., 4096 - estimate_bs = self.estimate_block_size # e.g., 1024 (15x faster) - ratio = cpu_block_size // estimate_bs # e.g., 4 + # Calculate padded K length for later use + padded_k_len = historical_k_len + padded_curr_k_len - # Use estimate_block_size for softmax kernel (optimized) - reshaped_est_bs = estimate_bs // self.stride # e.g., 1024/8 = 128 - norm = 1.0 # Normalization factor - scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm # log2(e) with scaling - segment_size = min(4096, reshaped_est_bs) + # ============================================================ + # Step 5: Apply softmax_fuse_block_sum with causal=True + # ============================================================ + cpu_block_size = block_size # e.g., 4096 + bsa_per_cpu = cpu_block_size // self.BSA_BLOCK_SIZE # e.g., 4096/128 = 32 + + # Use BSA_BLOCK_SIZE for block aggregation (aligned with GPU-only) + reshaped_bsa_bs = self.BSA_BLOCK_SIZE // self.stride # e.g., 128/8 = 16 + norm = 1.0 + scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm + segment_size = min(4096, reshaped_bsa_bs) with nvtx.range("xattn_estimate_softmax"): - block_sums_fine = softmax_fuse_block_sum( + block_sums = softmax_fuse_block_sum( attn_scores, - reshaped_est_bs, # Use optimized estimate block size (128 vs 512) + reshaped_bsa_bs, segment_size, - chunk_start=0, - chunk_end=q_reshaped_len, - real_q_len=q_reshaped_len, + chunk_start=chunk_start, + chunk_end=chunk_end, + real_q_len=real_q_len, scale=scale, - is_causal=False, # Historical blocks are all before current chunk + is_causal=True, # Causal for consistent with GPU-only ) - # block_sums_fine shape: [batch, heads, q_est_blocks, k_est_blocks] - # where k_est_blocks = len(available_blocks) * ratio + # block_sums shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks] - # Step 3: Aggregate to CPU block level (hierarchical sum) - # This is mathematically equivalent to direct computation but much faster - batch_size_bs, num_heads_bs, q_est_blocks, k_est_blocks = block_sums_fine.shape - num_cpu_blocks = len(available_blocks) + # ============================================================ + # Step 6: Use find_blocks_chunked to generate BSA-level mask + # ============================================================ + # Calculate BSA block indices + q_bsa_blocks = (padded_q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE + total_k_bsa_blocks = (padded_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE + historical_k_bsa_blocks = num_historical_blocks * bsa_per_cpu - with nvtx.range("xattn_estimate_aggregate"): - # Reshape: [batch, heads, q_est, k_est] -> [batch, heads, q_est, num_cpu, ratio] - block_sums_coarse = block_sums_fine.view( - batch_size_bs, num_heads_bs, q_est_blocks, num_cpu_blocks, ratio - ).sum(dim=-1) # [batch, heads, q_est_blocks, num_cpu_blocks] + # current_index for find_blocks_chunked: Q's block offset + q_start_bsa_block = historical_k_bsa_blocks # Q starts after historical K - # Sum over Q dimension to get total attention from Q chunk to each K block - cpu_block_scores = block_sums_coarse.sum(dim=2) # [batch, heads, num_cpu_blocks] + with nvtx.range("xattn_find_blocks"): + mask = find_blocks_chunked( + block_sums, + current_index=q_start_bsa_block, # Q's position in BSA blocks + threshold=self.threshold, + num_to_choose=None, + decoding=False, + mode="prefill", + causal=True, # Causal for block-level mask + ) + # mask shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks] - # Step 4: Select blocks using score + threshold (replaces mask + majority voting) - # This is simpler and more direct than the original mask-based approach - with nvtx.range("xattn_estimate_select"): - # Average scores across heads (GQA-aware: all heads contribute equally) - scores_per_block = cpu_block_scores.mean(dim=(0, 1)) # [num_cpu_blocks] + # ============================================================ + # Step 7: Extract mask portions and record density + # ============================================================ + B, H, Q_bsa, K_bsa_total = mask.shape - # Normalize to get attention distribution - total_score = scores_per_block.sum() - if total_score > 0: - score_ratio = scores_per_block / total_score - else: - # Edge case: all zeros, select all blocks - selected_block_ids = list(available_blocks) - if layer_id == 0 and available_blocks: - self._stats_total_available_blocks += len(available_blocks) - self._stats_total_selected_blocks += len(selected_block_ids) - self._stats_num_chunks += 1 - return selected_block_ids + # Calculate valid Q blocks (excluding padding) + valid_q_bsa = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE + valid_curr_k_bsa = (curr_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE - # Sort by score (descending) and select until threshold is reached - sorted_indices = torch.argsort(score_ratio, descending=True) - cumsum = 0.0 - selected_indices = set() + # 7a: Record historical blocks density + # IMPORTANT: For historical blocks, apply causal mask to match GPU-only density calculation! + # Q block i (global position = q_start_bsa_block + i) can see historical K block j + # only if j <= q_start_bsa_block + i (causal constraint) + mask_historical = mask[:, :, :valid_q_bsa, :historical_k_bsa_blocks] - for idx in sorted_indices.tolist(): - selected_indices.add(idx) - cumsum += score_ratio[idx].item() - if cumsum >= self.threshold: - break + if historical_k_bsa_blocks > 0: + # Create causal mask for historical blocks + # Q_global[i] = q_start_bsa_block + i, K[j] = j + # Causal: j <= Q_global[i] => j <= q_start_bsa_block + i + q_global_indices = torch.arange(valid_q_bsa, device=mask.device) + q_start_bsa_block + k_indices = torch.arange(historical_k_bsa_blocks, device=mask.device) + # Q at position q_global_indices[i] can see K at position k_indices[j] if k_indices[j] <= q_global_indices[i] + causal_mask_historical = k_indices.unsqueeze(0) <= q_global_indices.unsqueeze(1) # [valid_q_bsa, historical_k_bsa_blocks] - # Map indices back to block IDs - selected_block_ids = [available_blocks[i] for i in sorted(selected_indices)] + # Count positions within causal mask only + total_historical_causal = causal_mask_historical.sum().item() * B * H + selected_historical = (mask_historical & causal_mask_historical.unsqueeze(0).unsqueeze(0)).sum().item() + + if total_historical_causal > 0: + DensityObserver.record_counts(layer_id, selected_historical, total_historical_causal) + + # 7b: Record current chunk density (causal, to align with GPU-only mode) + # Current chunk is the portion after historical blocks + if valid_curr_k_bsa > 0: + # Extract current chunk mask (only valid portion, not padded) + mask_current = mask[:, :, :valid_q_bsa, historical_k_bsa_blocks:historical_k_bsa_blocks + valid_curr_k_bsa] + + q_dim = mask_current.shape[2] + k_dim = mask_current.shape[3] + + # Create causal mask (lower triangular) + # For current chunk: Q[i] can see K[j] where j <= i (standard causal) + causal_mask = torch.tril(torch.ones(q_dim, k_dim, device=mask.device, dtype=torch.bool)) + + # Count positions within causal mask only + total_current_causal = causal_mask.sum().item() * B * H + selected_current = (mask_current & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() + + if total_current_causal > 0: + DensityObserver.record_counts(layer_id, selected_current, total_current_causal) + + # Step 7.5: Save historical mask to pre-allocated buffer for compute_chunked_prefill + # Use full Q_bsa (padded) for buffer, not valid_q_bsa + mask_historical_full = mask[:, :, :, :historical_k_bsa_blocks] + if self._prefill_mask_buffer is not None: + # Only save historical portion of mask + self._prefill_mask_buffer[:, :, :Q_bsa, :historical_k_bsa_blocks].copy_(mask_historical_full) + self._current_mask_q_bsa = Q_bsa + self._current_mask_k_bsa = historical_k_bsa_blocks + + # ============================================================ + # Step 8: Aggregate mask to CPU block level (union of heads) + # ============================================================ + # Only aggregate historical blocks (current chunk is always full attention) + num_cpu_blocks = num_historical_blocks + + with nvtx.range("xattn_aggregate_mask"): + # Reshape historical mask: [B, H, Q_bsa, historical_k_bsa] -> [B, H, Q_bsa, num_cpu, bsa_per_cpu] + # Use full Q_bsa (not valid_q_bsa) for aggregation + mask_per_cpu = mask_historical_full.view(B, H, Q_bsa, num_cpu_blocks, bsa_per_cpu) + + # Union across: bsa_per_cpu, Q_bsa, heads -> [B, num_cpu] + cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1) # [B, num_cpu] + + # Get selected indices + selected_indices = cpu_needed[0].nonzero().squeeze(-1).tolist() + if isinstance(selected_indices, int): + selected_indices = [selected_indices] + + # Handle empty available_blocks case (first chunk) + if available_blocks: + selected_block_ids = [available_blocks[i] for i in selected_indices] + else: + selected_block_ids = [] # Always include first block (sink) and last block for safety if available_blocks and available_blocks[0] not in selected_block_ids: @@ -599,6 +758,14 @@ class XAttentionBSAPolicy(SparsePolicy): if available_blocks and available_blocks[-1] not in selected_block_ids: selected_block_ids.append(available_blocks[-1]) + # Record communication density (CPU block granularity) - only if there are historical blocks + if available_blocks: + DensityObserver.record_comm_density( + layer_id, + selected_cpu_blocks=len(selected_block_ids), + total_cpu_blocks=len(available_blocks), + ) + # Update statistics (only for layer 0 to avoid overcounting) if layer_id == 0 and available_blocks: self._stats_total_available_blocks += len(available_blocks) @@ -611,7 +778,7 @@ class XAttentionBSAPolicy(SparsePolicy): f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}") # Free intermediate tensors to prevent memory leak - del attn_scores, block_sums_fine, block_sums_coarse, cpu_block_scores, scores_per_block + del attn_scores, block_sums, mask, mask_historical_full return selected_block_ids @@ -637,6 +804,10 @@ class XAttentionBSAPolicy(SparsePolicy): 2. Compute attention to current chunk 3. Merge all results + Note: The BSA-level mask is saved in self._prefill_mask_buffer by select_blocks(). + Currently we use flash_attn_with_lse for computation (supports LSE merge). + TODO: Optimize to use BSA kernel with the saved mask for per-head sparse attention. + Args: q: Query tensor [seq_len, num_heads, head_dim] k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer) @@ -667,6 +838,11 @@ class XAttentionBSAPolicy(SparsePolicy): # Use the pre-selected blocks directly cpu_block_table = selected_blocks + # Note: BSA mask is available in self._prefill_mask_buffer (saved by select_blocks) + # Mask shape: [1, num_heads, Q_bsa, K_bsa] where Q_bsa = self._current_mask_q_bsa + # Selected indices: self._selected_cpu_indices, bsa_per_cpu: self._bsa_per_cpu + # TODO: Use this mask with BSA kernel for per-head sparse attention optimization + if cpu_block_table: with nvtx.range("xattn_compute_historical"): load_slots = list(range(offload_engine.num_ring_slots)) diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index f3d3d1a..4ae437b 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -221,20 +221,19 @@ class Attention(nn.Module): cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) # Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill) - selected_blocks = [] - if cpu_block_table: - num_chunks = current_chunk_idx + 1 - policy_ctx = PolicyContext( - query_chunk_idx=current_chunk_idx, - num_query_chunks=num_chunks, - layer_id=self.layer_id, - query=q, # Pass query for sparse policies that need it - is_prefill=True, - block_size=kvcache_manager.block_size, - total_kv_len=len(cpu_block_table) * kvcache_manager.block_size, - ) - selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx) - logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks") + # Always call select_blocks even for first chunk (cpu_block_table may be empty) + num_chunks = current_chunk_idx + 1 + policy_ctx = PolicyContext( + query_chunk_idx=current_chunk_idx, + num_query_chunks=num_chunks, + layer_id=self.layer_id, + query=q, # Pass query for sparse policies that need it + is_prefill=True, + block_size=kvcache_manager.block_size, + total_kv_len=len(cpu_block_table) * kvcache_manager.block_size if cpu_block_table else 0, + ) + selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k) + logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks") # [DEBUG] Verify execution path logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, " @@ -320,7 +319,7 @@ class Attention(nn.Module): block_size=kvcache_manager.block_size, total_kv_len=len(cpu_block_table) * kvcache_manager.block_size, ) - selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx) + selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k) logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks") # [DEBUG] Verify execution path diff --git a/nanovllm/utils/density_observer.py b/nanovllm/utils/density_observer.py index 3537980..4bc14f0 100644 --- a/nanovllm/utils/density_observer.py +++ b/nanovllm/utils/density_observer.py @@ -1,13 +1,22 @@ """ DensityObserver - Sparse Attention Density 统计 Observer。 -统计每层的 sparse attention density: -- density = selected_blocks / total_causal_blocks -- 在 causal attention 下,只计算下三角区域 +统计两种 density: +1. Compute Density (计算密度): 基于 BSA block size (128) + - density = selected_bsa_blocks / total_causal_bsa_blocks + - GPU-only 和 Offload 模式应该一致 + +2. Communication Density (通信密度): 基于 CPU block size (如 4096) + - comm_density = selected_cpu_blocks / total_cpu_blocks + - 仅用于 Offload 模式,由于粒度更粗,必然 >= compute density 统计位置: -- GPU-only: xattn_bsa.py compute_prefill() -- Offload: xattn_bsa.py select_blocks() +- GPU-only: xattn_bsa.py compute_prefill() - 只记录 compute density +- Offload: xattn_bsa.py select_blocks() - 记录两种 density + +对于 Offload 模式的 Density 计算: +- 不是简单的 avg 或 min +- 而是 sum(selected) / sum(total),正确处理不同 chunk 大小的权重 """ from typing import List, Dict, Optional, Tuple @@ -26,16 +35,26 @@ class DensityObserver(Observer): DensityObserver.complete_reset() # ... run inference ... DensityObserver.record(layer_id, mask, causal=True) + # 或者使用累积模式 (offload): + DensityObserver.record_counts(layer_id, selected, total) # ... DensityObserver.print_summary() """ _enabled: bool = False # 默认禁用 - # 每层的 density 记录 + # 每层的 compute density 记录 (BSA block 粒度) # key: layer_id, value: list of density values (每次 prefill chunk 一个) _layer_densities: Dict[int, List[float]] = {} + # 每层的 communication density 记录 (CPU block 粒度,仅 offload 模式) + _layer_comm_densities: Dict[int, List[float]] = {} + + # 累积模式: 记录 selected/total counts (用于 offload 模式) + # 这样可以在所有 chunks 完成后正确计算 density = sum(selected) / sum(total) + _layer_selected_counts: Dict[int, List[int]] = {} + _layer_total_counts: Dict[int, List[int]] = {} + # Mask shape 记录 (用于调试) _last_q_blocks: int = 0 _last_k_blocks: int = 0 @@ -56,7 +75,7 @@ class DensityObserver(Observer): causal: bool = True, ) -> float: """ - 记录一层的 density。 + 记录一层的 density (适用于 GPU-only 模式)。 Args: layer_id: 层 ID @@ -82,6 +101,72 @@ class DensityObserver(Observer): return density + @classmethod + def record_counts( + cls, + layer_id: int, + selected_blocks: int, + total_blocks: int, + ) -> None: + """ + 记录一层的 selected/total block counts (适用于 offload 累积模式)。 + + 使用累积计数而不是直接计算 density,这样在所有 chunks 处理完后可以正确计算: + overall_density = sum(selected) / sum(total) + + 这比 avg(density) 更准确,因为不同 chunk 的 Q 和 K 长度不同。 + + Args: + layer_id: 层 ID + selected_blocks: 这个 chunk 选中的 blocks 数量 + total_blocks: 这个 chunk 的 total possible blocks 数量 + """ + if not cls._enabled: + return + + # 初始化列表 + if layer_id not in cls._layer_selected_counts: + cls._layer_selected_counts[layer_id] = [] + if layer_id not in cls._layer_total_counts: + cls._layer_total_counts[layer_id] = [] + + # 累积记录 + cls._layer_selected_counts[layer_id].append(selected_blocks) + cls._layer_total_counts[layer_id].append(total_blocks) + + @classmethod + def record_comm_density( + cls, + layer_id: int, + selected_cpu_blocks: int, + total_cpu_blocks: int, + ) -> float: + """ + 记录一层的 communication density (CPU block 粒度)。 + + Args: + layer_id: 层 ID + selected_cpu_blocks: 选中的 CPU blocks 数量 + total_cpu_blocks: 总 CPU blocks 数量 + + Returns: + communication density 值 + """ + if not cls._enabled: + return 0.0 + + if total_cpu_blocks == 0: + return 1.0 + + comm_density = selected_cpu_blocks / total_cpu_blocks + + # 记录 + if layer_id not in cls._layer_comm_densities: + cls._layer_comm_densities[layer_id] = [] + cls._layer_comm_densities[layer_id].append(comm_density) + + return comm_density + @classmethod def _compute_density(cls, mask: torch.Tensor, causal: bool) -> float: """计算 mask 的 density""" @@ -107,22 +192,63 @@ class DensityObserver(Observer): def complete_reset(cls) -> None: """重置所有统计""" cls._layer_densities = {} + cls._layer_comm_densities = {} + cls._layer_selected_counts = {} + cls._layer_total_counts = {} cls._last_q_blocks = 0 cls._last_k_blocks = 0 cls._mode = "unknown" @classmethod def get_per_layer_density(cls) -> Dict[int, float]: - """获取每层的平均 density""" + """ + 获取每层的 density。 + + 对于累积模式 (offload): density = sum(selected) / sum(total) + 对于直接记录模式 (gpu_only): density = avg(density_values) + """ result = {} - for layer_id, densities in cls._layer_densities.items(): - if densities: - result[layer_id] = sum(densities) / len(densities) + + # 优先使用累积模式 (offload) + if cls._layer_selected_counts: + for layer_id in cls._layer_selected_counts: + selected_list = cls._layer_selected_counts.get(layer_id, []) + total_list = cls._layer_total_counts.get(layer_id, []) + total_selected = sum(selected_list) + total_total = sum(total_list) + if total_total > 0: + result[layer_id] = total_selected / total_total + else: + # 直接记录模式 (gpu_only) + for layer_id, densities in cls._layer_densities.items(): + if densities: + result[layer_id] = sum(densities) / len(densities) + return result @classmethod def get_overall_density(cls) -> float: - """获取所有层的平均 density""" + """ + 获取所有层的总体 compute density。 + + 对于累积模式 (offload): density = sum(all_selected) / sum(all_total) + 对于直接记录模式 (gpu_only): density = avg(all_density_values) + + 注意: 总体 density 不是简单的 avg(per_layer_density), + 而是 sum(all_selected) / sum(all_total),这样可以正确处理权重。 + """ + # 优先使用累积模式 (offload) + if cls._layer_selected_counts: + total_selected = 0 + total_total = 0 + for layer_id in cls._layer_selected_counts: + total_selected += sum(cls._layer_selected_counts[layer_id]) + total_total += sum(cls._layer_total_counts.get(layer_id, [])) + if total_total > 0: + return total_selected / total_total + return 0.0 + + # 直接记录模式 (gpu_only) all_densities = [] for densities in cls._layer_densities.values(): all_densities.extend(densities) @@ -130,6 +256,16 @@ class DensityObserver(Observer): return 0.0 return sum(all_densities) / len(all_densities) + @classmethod + def get_overall_comm_density(cls) -> float: + """获取所有层的平均 communication density""" + all_densities = [] + for densities in cls._layer_comm_densities.values(): + all_densities.extend(densities) + if not all_densities: + return 0.0 + return sum(all_densities) / len(all_densities) + @classmethod def get_summary(cls) -> dict: """返回统计摘要""" @@ -160,8 +296,13 @@ class DensityObserver(Observer): per_layer = cls.get_per_layer_density() overall = cls.get_overall_density() min_layer, min_density = cls.get_min_density() + overall_comm = cls.get_overall_comm_density() print(f"[DensityObserver] Mode: {cls._mode}") - print(f" Overall density: {overall:.4f}") - print(f" Min density: {min_density:.4f} (layer {min_layer})") + print(f" Compute density: {overall:.4f} (min: {min_density:.4f} @ layer {min_layer})") + if overall_comm > 0: + print(f" Comm density: {overall_comm:.4f}") print(f" Num layers: {len(per_layer)}") + # 输出 layer 0 的 density 用于对比 + if 0 in per_layer: + print(f" Layer 0 density: {per_layer[0]:.6f}") diff --git a/tests/test_xattn_kernels.py b/tests/test_xattn_kernels.py index b4800c8..8e5fcfb 100644 --- a/tests/test_xattn_kernels.py +++ b/tests/test_xattn_kernels.py @@ -41,9 +41,9 @@ K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda() for i in range(q_len): if i % 2 == 0: - Q[0, 0, i, :] = 1 + Q[0, 0, i, :] = 1 * (i // stride + 1) else: - Q[0, 0, i, :] = 2 + Q[0, 0, i, :] = 2 * (i // stride + 1) for i in range(kv_len): if i % 2 == 0: @@ -74,8 +74,11 @@ for k_chunk_idx in range(num_k_chunks): Q, K_chunk, stride, chunk_start=0, chunk_end=q_reshaped_len, - is_causal=False + is_causal=True ) + + __import__('pdb').set_trace() + attn_scores_list.append(attn_chunk) # 拼接所有 K chunks 的结果