Compare commits

2 Commits

Author SHA1 Message Date
Zijie Tian
2e96d1d97d WIP: Enhance sparse attention with density tracking and block selection improvements
- Added analysis documentation for xattn density alignment.
- Refactored ModelRunner to pre-allocate policy metadata buffers regardless of CPU offload configuration.
- Updated FullAttentionPolicy and SparsePolicy to accept query and key tensors for block selection.
- Enhanced QuestPolicy to utilize query tensor for block selection and improved handling of selected blocks.
- Expanded XAttentionBSAPolicy to support chunked prefill and improved attention score computation with historical and current chunk handling.
- Introduced DensityObserver to track compute and communication density for sparse attention layers.
- Updated attention layer to ensure block selection is always called, improving robustness in first chunk scenarios.
- Added tests for attention kernel behavior with enhanced input patterns.
2026-01-31 14:48:23 +08:00
Zijie Tian
f6ac4ccdde feat: add DensityObserver for XAttention sparse attention density tracking
- Add DensityObserver class to track per-layer density statistics
- Integrate DensityObserver into compute_prefill for GPU-only mode
- Fix stride parameter not being passed to xattn_estimate
- Add density statistics output to test_ruler.py for XATTN_BSA
- Add comprehensive density benchmark documentation

Key changes:
- nanovllm/utils/density_observer.py: New Observer for density tracking
- xattn_bsa.py: Add stride param to xattn_estimate, integrate DensityObserver
- test_ruler.py: Enable DensityObserver and print summary for XATTN_BSA
- docs/xattn_density_benchmark.md: Benchmark results for 4K-32K contexts

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-30 16:26:56 +08:00
11 changed files with 863 additions and 145 deletions

View File

@@ -18,6 +18,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/xattn_kernels_guide.md`](docs/xattn_kernels_guide.md) | XAttention Triton kernels: flat_group_gemm (anti-diagonal summation), softmax_fuse_block_sum (block aggregation) |
 | [`docs/xattn_chunked_prefill.md`](docs/xattn_chunked_prefill.md) | XAttention chunked prefill: API, usage, consistency requirements |
 | [`docs/xattn_bsa_policy_design.md`](docs/xattn_bsa_policy_design.md) | XAttention BSA Policy: algorithm design, performance benchmarks (128K), memory management, density statistics |
+| [`docs/xattn_density_benchmark.md`](docs/xattn_density_benchmark.md) | 📊 XAttention Density Benchmark: 4K-32K contexts, stride parameter, per-layer density analysis |
 | [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) interface doc: function signatures, usage examples, constraints |
 | [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
 | [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
@@ -36,6 +37,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/estimate_block_size_performance.md`](docs/estimate_block_size_performance.md) | 🔥 PERF: estimate-phase block_size performance analysis; softmax_fuse_block_sum optimum at 512-1024, current 4096 is 15x slower |
 | [`docs/long_context_models_1m.md`](docs/long_context_models_1m.md) | 📚 REF: models with 1M+ context length (Qwen/GLM/InternLM/Llama/VL); recommended models ≤10B |
 | [`docs/new_model_integration_guide.md`](docs/new_model_integration_guide.md) | 🔧 GUIDE: new-model integration guide - config mapping, RoPE variants, EOS handling, weight conversion, validation checklist |
+| [`docs/xattn_density_alignment_analysis.md`](docs/xattn_density_alignment_analysis.md) | 📊 ANALYSIS: GPU-only vs offload mode density alignment; chunked softmax boundary effect; root cause of the 5-7% gap |
 ## Rules Index

View File

@@ -0,0 +1,195 @@
# XAttention Density Benchmark
Density test results for XAttention Block Sparse Attention in GPU-only mode.
## Test Configuration
| Parameter | Value | Notes |
|------|-----|------|
| Model | Llama-3.1-8B-Instruct | 32 layers, 32 heads, 8 KV heads |
| Block Size | 128 tokens | Fixed requirement of the BSA kernel |
| Threshold | 0.9 / 0.95 | Cumulative attention threshold |
| Stride | 4 / 8 / 16 | Q/K downsampling stride |
| Dataset | RULER niah_single_1 | Sample 0 |
| Mode | GPU-only | No CPU offload |
## Density Definition
```python
import torch

# Density = selected_blocks / total_causal_blocks
# Under causal attention, only blocks in the lower-triangular region count.
# Overall density = mean over all layers.
def compute_density(mask: torch.Tensor, causal: bool = True) -> float:
    """
    mask: [batch, heads, q_blocks, k_blocks] boolean tensor
    """
    batch, heads, q_blocks, k_blocks = mask.shape
    if causal:
        causal_mask = torch.tril(
            torch.ones(q_blocks, k_blocks, dtype=torch.bool, device=mask.device)
        )
        total = causal_mask.sum().item() * batch * heads
        selected = (mask & causal_mask).sum().item()
    else:
        total = mask.numel()
        selected = mask.sum().item()
    return selected / total
```
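As a worked check of the definition, here is a toy mask on a 4x4 block grid with a hypothetical "diagonal + sink block" selection pattern (the pattern itself is illustrative, not from the benchmark):

```python
import torch

# Toy check of the density formula: 1 batch, 2 heads, 4x4 block grid.
# Hypothetical selection: the diagonal plus the first ("sink") key block.
batch, heads, q_blocks, k_blocks = 1, 2, 4, 4
mask = torch.zeros(batch, heads, q_blocks, k_blocks, dtype=torch.bool)
mask[..., torch.arange(q_blocks), torch.arange(k_blocks)] = True  # diagonal blocks
mask[..., :, 0] = True                                            # sink block

causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, dtype=torch.bool))
total = causal_mask.sum().item() * batch * heads   # 10 causal blocks per head -> 20
selected = (mask & causal_mask).sum().item()       # 7 per head -> 14
print(selected / total)  # 0.7
```

Per head, 7 of the 10 causal blocks are selected (4 on the diagonal, 3 extra in the sink column), giving a density of 0.7.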
## Test Results
### threshold=0.9
#### Overall Density (mean)
| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.5220 (52.2%) | 0.5292 (52.9%) | 0.5430 (54.3%) |
| **8K** | 0.5152 (51.5%) | 0.5252 (52.5%) | 0.5396 (54.0%) |
| **16K** | 0.4682 (46.8%) | 0.4775 (47.8%) | 0.4888 (48.9%) |
| **32K** | 0.3700 (37.0%) | 0.4012 (40.1%) | 0.4196 (42.0%) |
#### Min Density (per layer)
| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.2805 (Layer 3) | 0.3132 (Layer 3) | 0.3376 (Layer 5) |
| **8K** | 0.2886 (Layer 5) | 0.2725 (Layer 5) | 0.2995 (Layer 5) |
| **16K** | 0.2247 (Layer 5) | 0.2349 (Layer 5) | 0.2451 (Layer 5) |
| **32K** | 0.1799 (Layer 5) | 0.1846 (Layer 5) | 0.1964 (Layer 5) |
### threshold=0.95
#### Overall Density (mean)
| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.6561 (65.6%) | 0.6699 (67.0%) | 0.6815 (68.2%) |
| **8K** | 0.6462 (64.6%) | 0.6584 (65.8%) | 0.6732 (67.3%) |
| **16K** | 0.6004 (60.0%) | 0.6114 (61.1%) | 0.6193 (61.9%) |
| **32K** | 0.4894 (48.9%) | 0.5203 (52.0%) | 0.5385 (53.9%) |
#### Min Density (per layer)
| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.3972 (Layer 3) | 0.4348 (Layer 5) | 0.4517 (Layer 4) |
| **8K** | 0.4004 (Layer 5) | 0.3906 (Layer 5) | 0.4239 (Layer 5) |
| **16K** | 0.3331 (Layer 5) | 0.3453 (Layer 5) | 0.3589 (Layer 5) |
| **32K** | 0.2656 (Layer 5) | 0.2784 (Layer 5) | 0.2917 (Layer 5) |
### Threshold Comparison (stride=8)
| Context | threshold=0.9 | threshold=0.95 | Difference |
|---------|---------------|----------------|------|
| **4K** | 0.5292 (52.9%) | 0.6699 (67.0%) | -14.1% |
| **8K** | 0.5252 (52.5%) | 0.6584 (65.8%) | -13.3% |
| **16K** | 0.4775 (47.8%) | 0.6114 (61.1%) | -13.4% |
| **32K** | 0.4012 (40.1%) | 0.5203 (52.0%) | -11.9% |
## Key Findings
### 1. Context Length Has the Largest Effect
Density drops significantly as context length grows (threshold=0.9, stride=8):
- 4K: 52.9% density
- 8K: 52.5% density
- 16K: 47.8% density
- 32K: 40.1% density
**Conclusion**: longer sequences offer more sparsification opportunity; XAttention's advantage is more pronounced at long context.
### 2. Threshold Has a Significant Effect
Density at threshold=0.9 is about 12-14 percentage points lower than at 0.95:
- 0.9 is more aggressive and selects fewer blocks
- 0.95 is more conservative and keeps more blocks
- Accuracy is unaffected either way (all RULER NIAH tests PASS)
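The intuition behind the cumulative threshold can be sketched in a few lines: sort block scores descending and keep blocks until their cumulative attention mass reaches the threshold. The actual kernel (`find_blocks_chunked`) operates on block-level softmax sums per head; this toy version, with made-up scores, only illustrates why a lower threshold selects fewer blocks:

```python
import torch

# Per-block attention mass for one (query block, head); values sum to 1.0.
# These scores are illustrative, not measured.
scores = torch.tensor([0.40, 0.25, 0.15, 0.10, 0.06, 0.04])

def blocks_for_threshold(scores: torch.Tensor, threshold: float) -> int:
    vals, _ = scores.sort(descending=True)
    csum = vals.cumsum(0)
    # First position where cumulative mass reaches the threshold
    # (small epsilon tolerates float rounding); +1 = number of blocks kept.
    return int((csum >= threshold - 1e-9).nonzero()[0].item()) + 1

print(blocks_for_threshold(scores, 0.90))  # 4 blocks kept
print(blocks_for_threshold(scores, 0.95))  # 5 blocks kept -> higher density
```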
### 3. Stride Has a Small Effect
At the same context length, density varies only about 2-5 percentage points across strides:
- larger stride → slightly higher density (coarser sampling, more conservative selection)
- stride=4 is the most aggressive; stride=16 the most conservative
### 4. Min Density Concentrates in the Middle Layers
- In most cases the minimum density occurs at Layer 5
- Middle layers are the sparsest; the first and last layers are relatively dense
- This matches the usual pattern of Transformer attention
### 5. Best Sparsification Configuration
32K + stride=4 + threshold=0.9 reaches the lowest density:
- Overall: **37.0%** (a 63% compute saving)
- Min: **18.0%** (Layer 5)
### 6. Accuracy Is Stable
All configurations PASS the RULER NIAH test (score=1.0), which shows:
- threshold=0.9 and 0.95 are both conservative enough to avoid accuracy loss
- different strides do not change the final result
## Recommended Configurations
| Scenario | threshold | stride | Notes |
|------|-----------|--------|------|
| Accuracy-first | 0.95 | 8 | Conservative; density ~52-67% |
| Balanced | 0.9 | 8 | Default; density ~40-53% |
| Performance-first | 0.9 | 4 | Aggressive; density ~37-52% |
## Test Command
```bash
# Basic test
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--data-dir tests/data/ruler_32k \
--datasets niah_single_1 \
--sample-indices 0 \
--max-model-len 33792 \
--sparse-policy XATTN_BSA \
--sparse-threshold 0.9 \
--sparse-stride 8 \
--gpu-utilization 0.85
# Parameter notes
# --sparse-policy XATTN_BSA   enable XAttention Block Sparse Attention
# --sparse-threshold 0.9      cumulative attention threshold (0.9-0.99)
# --sparse-stride 8           Q/K downsampling stride (4/8/16)
```
## Using DensityObserver
```python
from nanovllm.utils.density_observer import DensityObserver
# Enable and reset
DensityObserver.enable()
DensityObserver.complete_reset()
# ... run inference (compute_prefill records automatically) ...
# Get results
summary = DensityObserver.get_summary()
# {
#     "mode": "gpu_only",
#     "overall_density": 0.40,  # mean over all layers
#     "per_layer_density": {0: 0.55, 1: 0.45, ...},
#     "num_layers": 32
# }
# Get the minimum density
min_layer, min_density = DensityObserver.get_min_density()
# Print a summary
DensityObserver.print_summary()
# [DensityObserver] Mode: gpu_only
#   Overall density: 0.4012
#   Min density: 0.1846 (layer 5)
#   Num layers: 32
```
## Related Files
| File | Description |
|------|------|
| `nanovllm/kvcache/sparse/xattn_bsa.py` | XAttention BSA Policy implementation |
| `nanovllm/utils/density_observer.py` | Observer for density statistics |
| `nanovllm/ops/xattn.py` | xattn_estimate core algorithm |
| `tests/test_ruler.py` | RULER benchmark test script |

View File

@@ -229,16 +229,16 @@ class ModelRunner:
         # GPU-only mode: pre-allocate policy metadata buffers
         # This avoids dynamic GPU memory allocation during forward pass
-        if not config.enable_cpu_offload:
-            num_heads = hf_config.num_attention_heads // self.world_size
-            self.kvcache_manager.sparse_policy.alloc_policy_metadata(
-                num_heads=num_heads,
-                num_kv_heads=num_kv_heads,
-                head_dim=head_dim,
-                max_seq_len=config.max_model_len,
-                dtype=hf_config.torch_dtype,
-                device=torch.device("cuda"),
-            )
+        # if not config.enable_cpu_offload:
+        num_heads = hf_config.num_attention_heads // self.world_size
+        self.kvcache_manager.sparse_policy.alloc_policy_metadata(
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
+            max_seq_len=config.max_model_len,
+            dtype=hf_config.torch_dtype,
+            device=torch.device("cuda"),
+        )
         # Log policy info (handle both enum and None cases)
         policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"

View File

@@ -47,6 +47,8 @@ class FullAttentionPolicy(SparsePolicy):
         available_blocks: List[int],
         offload_engine: "OffloadEngine",
         ctx: PolicyContext,
+        q: torch.Tensor,
+        k: torch.Tensor,
     ) -> List[int]:
         """Return all blocks - no sparsity."""
         # Update statistics (only for layer 0 to avoid overcounting)

View File

@@ -142,6 +142,8 @@ class SparsePolicy(ABC):
         available_blocks: List[int],
         offload_engine: "OffloadEngine",
         ctx: PolicyContext,
+        q: torch.Tensor,
+        k: torch.Tensor,
     ) -> List[int]:
         """
         Select which KV blocks to load for the current query chunk.
@@ -158,6 +160,8 @@
                 to load KV to make selection decisions).
             ctx: PolicyContext with information about the current query
                 chunk, layer, phase (prefill/decode), etc.
+            q: Query tensor [seq_len, num_heads, head_dim] for current chunk
+            k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk
         Returns:
             List of block IDs to load (must be a subset of available_blocks).
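The contract above (return a subset of `available_blocks`, given the current chunk's `q`/`k`) can be sketched standalone. This is a hedged toy policy with the library's `PolicyContext` and `OffloadEngine` left as unused placeholders; `RecencyPolicy` and `window_blocks` are hypothetical names, not part of nano-vllm:

```python
from typing import List

import torch


class RecencyPolicy:
    """Toy sliding-window policy: keep only the most recent N historical blocks."""

    def __init__(self, window_blocks: int = 2):
        self.window_blocks = window_blocks

    def select_blocks(
        self,
        available_blocks: List[int],
        offload_engine,          # unused in this sketch
        ctx,                     # unused in this sketch
        q: torch.Tensor,         # [seq_len, num_heads, head_dim]
        k: torch.Tensor,         # [seq_len, num_kv_heads, head_dim]
    ) -> List[int]:
        # Must return a subset of available_blocks; here, the last N blocks.
        return available_blocks[-self.window_blocks:]


policy = RecencyPolicy(window_blocks=2)
q = torch.randn(8, 4, 16)
k = torch.randn(8, 2, 16)
print(policy.select_blocks([0, 1, 2, 3], None, None, q, k))  # [2, 3]
```

A real policy would score blocks from `q` (as QuestPolicy and XAttentionBSAPolicy do below) rather than ignore it.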

View File

@@ -191,13 +191,26 @@ class QuestPolicy(SparsePolicy):
     def select_blocks(
         self,
         available_blocks: List[int],
         offload_engine: "OffloadEngine",
         ctx: PolicyContext,
+        q: torch.Tensor,
+        k: torch.Tensor,
     ) -> List[int]:
         """
         Select Top-K blocks based on query-key similarity bounds.
         If query is not available (some prefill scenarios), falls back
         to loading all blocks.
+
+        Args:
+            available_blocks: List of CPU block IDs
+            offload_engine: OffloadEngine for loading KV (unused in Quest)
+            ctx: PolicyContext with metadata
+            q: Query tensor [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
+            k: Key tensor [seq_len, num_kv_heads, head_dim] (unused in Quest, uses metadata instead)
+
+        Returns:
+            Selected block IDs
         """
         if self.metadata is None:
             raise RuntimeError(
@@ -211,7 +224,7 @@ class QuestPolicy(SparsePolicy):
         if n <= self.config.threshold_blocks:
             return available_blocks
-        if ctx.query is None:
+        if q is None:
             # No query available - cannot compute scores
             return available_blocks
@@ -221,11 +234,10 @@ class QuestPolicy(SparsePolicy):
         )
         # Metadata is already on GPU, same device as query
-        device = ctx.query.device
+        device = q.device
         # Compute upper bound scores
-        # query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
-        q = ctx.query
+        # query shape: [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
         if q.dim() == 4:
             # Prefill: use mean over sequence length
             q = q.mean(dim=1)  # [1, num_heads, head_dim]

View File

@@ -17,6 +17,7 @@ import torch.cuda.nvtx as nvtx
 from typing import List, Tuple, TYPE_CHECKING
 from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
+from nanovllm.utils.density_observer import DensityObserver
 if TYPE_CHECKING:
     from nanovllm.kvcache.offload_engine import OffloadEngine
@@ -134,6 +135,21 @@ class XAttentionBSAPolicy(SparsePolicy):
         self._v_expanded: torch.Tensor | None = None
         self._max_seq_len: int = 0
+        # Pre-allocated mask buffer for chunked prefill (offload mode)
+        # Stores BSA-level mask from select_blocks for use in compute_chunked_prefill
+        # Shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks]
+        self._prefill_mask_buffer: torch.Tensor | None = None
+        self._current_mask_q_bsa: int = 0  # Current Q BSA blocks in buffer
+        self._current_mask_k_bsa: int = 0  # Current K BSA blocks in buffer
+
+        # Selected block indices for mask extraction in compute_chunked_prefill
+        # Stores the indices of selected CPU blocks in available_blocks
+        self._selected_cpu_indices: List[int] = []
+        self._bsa_per_cpu: int = 0  # BSA blocks per CPU block
+
+        #> Debug: store all K cache
+        self._debug_k_full: torch.Tensor | None = None
+
     def alloc_policy_metadata(
         self,
         num_heads: int,
@@ -161,7 +177,17 @@ class XAttentionBSAPolicy(SparsePolicy):
             dtype: Data type
             device: Target device
         """
-        # Only allocate if GQA (num_heads != num_kv_heads)
+        # Pre-allocate mask buffer for chunked prefill (offload mode)
+        # mask shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks]
+        # This is needed regardless of GQA
+        max_q_bsa_blocks = self.chunk_size // self.BSA_BLOCK_SIZE
+        max_k_bsa_blocks = max_seq_len // self.BSA_BLOCK_SIZE
+        mask_shape = (1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks)
+        self._prefill_mask_buffer = torch.empty(mask_shape, dtype=torch.bool, device=device)
+        mask_memory_mb = num_heads * max_q_bsa_blocks * max_k_bsa_blocks / (1024 * 1024)
+        logger.info(f"[XAttn] Pre-allocated mask buffer: shape={mask_shape}, memory={mask_memory_mb:.1f} MB")
+
+        # Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads)
         if num_heads == num_kv_heads:
             logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
             return
@@ -176,6 +202,9 @@ class XAttentionBSAPolicy(SparsePolicy):
         memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)
         logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")
+        #DEBUG : buffer for save all K cache.
+        self._debug_k_full = torch.empty((1, num_heads, max_seq_len, head_dim), dtype=dtype, device=device)
+
     # =========================================================================
     # GPU-only methods (non-chunked)
     # =========================================================================
@@ -258,6 +287,10 @@ class XAttentionBSAPolicy(SparsePolicy):
         from nanovllm.ops.xattn import xattn_estimate
+        # Set DensityObserver mode on first layer
+        if layer_id == 0:
+            DensityObserver.set_mode("gpu_only")
+
         # Get dimensions
         total_q, num_heads, head_dim = q.shape
         total_kv, num_kv_heads, _ = k.shape
@@ -315,6 +348,7 @@ class XAttentionBSAPolicy(SparsePolicy):
             Q, K_exp,
             chunk_size=self.chunk_size,
             block_size=self.BSA_BLOCK_SIZE,
+            stride=self.stride,
             threshold=self.threshold,
             use_triton=self.use_triton,
             causal=True,
@@ -360,13 +394,8 @@ class XAttentionBSAPolicy(SparsePolicy):
             is_causal=True,
         )
-        # Update statistics (layer 0 only to avoid overcounting)
-        if layer_id == 0:
-            selected_blocks = mask_trimmed.sum().item()
-            total_blocks = q_block_num * k_block_num * num_heads
-            density = selected_blocks / total_blocks if total_blocks > 0 else 1.0
-            logger.debug(f"[XAttn GPU-only] layer={layer_id}, q_blocks={q_block_num}, "
-                         f"k_blocks={k_block_num}, density={density:.1%}")
+        # Record density for all layers via DensityObserver
+        DensityObserver.record(layer_id, mask_trimmed, causal=True)
         return output
@@ -400,33 +429,42 @@ class XAttentionBSAPolicy(SparsePolicy):
         available_blocks: List[int],
         offload_engine: "OffloadEngine",
         ctx: PolicyContext,
+        q: torch.Tensor,
+        k: torch.Tensor,
     ) -> List[int]:
         """
         Compute attention scores for all available blocks using flat_group_gemm,
         then use softmax_fuse_block_sum and find_blocks_chunked to select important blocks.
-        This method:
-        1. Loads each K block from CPU
-        2. Computes Q@K^T attention scores using XAttention stride reshape
-        3. Applies softmax_fuse_block_sum to get block-level attention
-        4. Uses find_blocks_chunked to select blocks based on threshold
+        This method aligns with GPU-only xattn_estimate_chunked:
+        1. Loads each K block from CPU (historical blocks)
+        2. Gets current chunk K from prefill buffer
+        3. Concatenates [historical K, current chunk K] for correct softmax normalization
+        4. Uses causal=True with correct chunk_start for position-aware masking
+        5. Only selects from historical blocks (current chunk is always full attention)
         Args:
-            available_blocks: List of CPU block IDs
+            available_blocks: List of CPU block IDs (historical blocks only)
             offload_engine: OffloadEngine for loading blocks
-            ctx: PolicyContext with query tensor and metadata
+            ctx: PolicyContext with metadata
+            q: Query tensor [seq_len, num_heads, head_dim] for current chunk
+            k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk (used for estimation)
         Returns:
             Selected block IDs based on attention threshold
         """
-        if not available_blocks or ctx.query is None:
+        if q is None:
             return available_blocks
         from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum, find_blocks_chunked
         import math
         layer_id = ctx.layer_id
-        q = ctx.query  # [seq_len, num_heads, head_dim]
+        # Use passed q parameter instead of ctx.query
+
+        # Set DensityObserver mode on first layer
+        if layer_id == 0:
+            DensityObserver.set_mode("offload")
         # Convert Q to [batch, heads, seq_len, head_dim]
         # q: [seq_len, num_heads, head_dim] -> [1, num_heads, seq_len, head_dim]
@@ -453,18 +491,37 @@ class XAttentionBSAPolicy(SparsePolicy):
         q_reshaped_len = padded_q_len // self.stride
-        # Use a single slot for loading (synchronous mode for simplicity)
-        slot = 0
-        attn_scores_list = []
-
-        # Get block size from context
-        block_size = ctx.block_size  # tokens per CPU block (e.g., 1024)
-        reshaped_block_size = block_size // self.stride  # e.g., 1024/8 = 128
-
-        with nvtx.range("xattn_estimate_gemm"):
+        # Get block size from context
+        block_size = ctx.block_size  # tokens per CPU block (e.g., 4096)
+        reshaped_block_size = block_size // self.stride  # e.g., 4096/8 = 512
+
+        # ============================================================
+        # Step 1: Compute chunk_start and related parameters
+        # ============================================================
+        # chunk_start = Q's global position in reshaped space
+        # Q starts at position: num_historical_blocks * block_size
+        num_historical_blocks = len(available_blocks)
+        historical_k_len = num_historical_blocks * block_size
+        chunk_start = historical_k_len // self.stride  # Q's position in reshaped space
+        chunk_end = chunk_start + q_reshaped_len
+
+        # For valid Q length tracking (excluding padding)
+        valid_q_reshaped = (q_len + self.stride - 1) // self.stride
+        real_q_len = chunk_start + valid_q_reshaped
+
+        # ============================================================
+        # Step 2: Pipeline load historical K blocks and compute attn_scores
+        # ============================================================
+        # Key design: Load each block, compute immediately, then release
+        # This avoids storing all K in GPU memory at once (offload-friendly)
+        slot = 0
+        attn_scores_list = []
+        BLOCK_N = 128
+        k_alignment = self.stride * BLOCK_N
+
+        with nvtx.range("xattn_estimate_historical"):
             for cpu_block_id in available_blocks:
                 # Load only K from CPU to GPU (V not needed for estimate)
-                # This saves 50% communication in the estimate phase
                 offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
                 offload_engine.wait_slot_layer(slot)
@@ -472,8 +529,7 @@
                 k_block = offload_engine.get_k_for_slot(slot)
                 # Convert K to [batch, heads, k_len, head_dim]
-                # k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim]
-                K_chunk = k_block.transpose(1, 2)
+                K_chunk = k_block.transpose(1, 2)  # [1, num_kv_heads, block_size, head_dim]
                 # Handle GQA: expand K heads to match Q heads
                 num_kv_heads = K_chunk.shape[1]
@@ -481,116 +537,220 @@ class XAttentionBSAPolicy(SparsePolicy):
                 num_groups = num_heads // num_kv_heads
                 K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
-                # Pad K if necessary (k_len must be divisible by stride * BLOCK_N)
-                k_len = K_chunk.shape[2]
-                BLOCK_N = 128
-                k_alignment = self.stride * BLOCK_N
-                if k_len < k_alignment:
-                    # K too short, pad it
-                    pad_size = k_alignment - k_len
-                    K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
-
-                # Compute attention scores using flat_group_gemm_fuse_reshape
-                # Output: [batch, heads, q_len/stride, k_len/stride]
-                attn_chunk = flat_group_gemm_fuse_reshape(
-                    Q, K_chunk, self.stride,
-                    chunk_start=0,
-                    chunk_end=q_reshaped_len,
-                    is_causal=False
-                )
-                attn_scores_list.append(attn_chunk)
+                #> DEBUG: save all K cache
+                start_pos = cpu_block_id * block_size
+                self._debug_k_full[:, :, start_pos:start_pos + block_size, :].copy_(K_chunk)
+
+                # # Pad K if necessary
+                # k_len = K_chunk.shape[2]
+                # if k_len < k_alignment:
+                #     pad_size = k_alignment - k_len
+                #     K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
+
+                # # Compute attention scores for this historical block
+                # # Historical blocks: all positions < Q, so Q always sees them (full attention)
+                # # Use LOCAL chunk_start=0 to match test_xattn_k_chunked.py behavior
+                # attn_chunk = flat_group_gemm_fuse_reshape(
+                #     Q, K_chunk, self.stride,
+                #     chunk_start=0,  # Local: same as test
+                #     chunk_end=q_reshaped_len,
+                #     is_causal=False,  # Historical K: all visible to Q
+                # )
+                # attn_scores_list.append(attn_chunk)
                 # Mark slot as done for reuse
                 offload_engine.record_slot_compute_done(slot)
-        # Concatenate all attention scores along K dimension
-        # Each chunk: [1, heads, q_reshaped_len, block_reshaped_len]
-        # Result: [1, heads, q_reshaped_len, total_k_reshaped_len]
+        num_kv_heads = k.shape[1]
+        if num_heads != num_kv_heads:
+            num_groups = num_heads // num_kv_heads
+            k_repeated = k.repeat_interleave(num_groups, dim=1).unsqueeze(0).transpose(1, 2)  # [1, num_heads, historical_k_len, head_dim]
+        self._debug_k_full[:, :, historical_k_len:historical_k_len + q_len, :].copy_(k_repeated)
+
+        if layer_id == 0:
+            __import__('pdb').set_trace()
+
+        # ============================================================
+        # Step 3: Get current chunk K and compute its attn_scores
+        # ============================================================
+        with nvtx.range("xattn_estimate_current"):
+            # Current chunk K is in prefill buffer (already on GPU)
+            k_curr, _ = offload_engine.get_prefill_buffer_slice(layer_id, q_len)
+            # k_curr: [1, q_len, num_kv_heads, head_dim] -> [1, num_kv_heads, q_len, head_dim]
+            K_current = k_curr.transpose(1, 2)
+
+            # Handle GQA for current chunk K
+            num_kv_heads = K_current.shape[1]
+            if num_heads != num_kv_heads:
+                num_groups = num_heads // num_kv_heads
+                K_current = K_current.repeat_interleave(num_groups, dim=1)
+
+            # Pad current K if necessary
+            curr_k_len = K_current.shape[2]
+            padded_curr_k_len = ((curr_k_len + k_alignment - 1) // k_alignment) * k_alignment
+            if padded_curr_k_len != curr_k_len:
+                pad_size = padded_curr_k_len - curr_k_len
+                K_current = torch.nn.functional.pad(K_current, (0, 0, 0, pad_size), value=0)
+
+            # Compute attention scores for current chunk
+            # IMPORTANT: Use LOCAL coordinates (0 to q_reshaped_len) for current chunk!
+            # Because K_current only contains current chunk K (not full sequence),
+            # block_n in kernel starts from 0. Using global chunk_start would cause
+            # incorrect causal mask (Q would see K blocks it shouldn't).
+            attn_current = flat_group_gemm_fuse_reshape(
+                Q, K_current, self.stride,
+                chunk_start=0,  # Local: Q starts at 0 relative to K_current
+                chunk_end=q_reshaped_len,  # Local: Q ends at q_reshaped_len
+                is_causal=True,  # Current chunk: apply causal mask
+            )
+            attn_scores_list.append(attn_current)
+            del K_current
+
+        # ============================================================
+        # Step 4: Concatenate all attn_scores
+        # ============================================================
         if not attn_scores_list:
             return available_blocks
         attn_scores = torch.cat(attn_scores_list, dim=-1)
-        # Free intermediate list immediately
         del attn_scores_list
-        # Step 2: Apply softmax_fuse_block_sum with hierarchical aggregation
-        # Use smaller estimate_block_size (1024) for 15x faster softmax kernel,
-        # then aggregate to CPU block level (4096).
-        #
-        # Hierarchical approach:
-        # 1. softmax_fuse_block_sum with estimate_block_size (1024) -> fine-grained scores
-        # 2. Aggregate: reshape + sum -> CPU block level scores
-        # 3. Select blocks based on score + threshold (NOT mask + voting)
-        cpu_block_size = block_size  # e.g., 4096
-        estimate_bs = self.estimate_block_size  # e.g., 1024 (15x faster)
-        ratio = cpu_block_size // estimate_bs  # e.g., 4
-
-        # Use estimate_block_size for softmax kernel (optimized)
-        reshaped_est_bs = estimate_bs // self.stride  # e.g., 1024/8 = 128
-        norm = 1.0  # Normalization factor
-        scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm  # log2(e) with scaling
-        segment_size = min(4096, reshaped_est_bs)
+        # Calculate padded K length for later use
+        padded_k_len = historical_k_len + padded_curr_k_len
+
+        # ============================================================
+        # Step 5: Apply softmax_fuse_block_sum with causal=True
+        # ============================================================
+        cpu_block_size = block_size  # e.g., 4096
+        bsa_per_cpu = cpu_block_size // self.BSA_BLOCK_SIZE  # e.g., 4096/128 = 32
+
+        # Use BSA_BLOCK_SIZE for block aggregation (aligned with GPU-only)
+        reshaped_bsa_bs = self.BSA_BLOCK_SIZE // self.stride  # e.g., 128/8 = 16
+        norm = 1.0
+        scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm
+        segment_size = min(4096, reshaped_bsa_bs)
         with nvtx.range("xattn_estimate_softmax"):
-            block_sums_fine = softmax_fuse_block_sum(
+            block_sums = softmax_fuse_block_sum(
                 attn_scores,
-                reshaped_est_bs,  # Use optimized estimate block size (128 vs 512)
+                reshaped_bsa_bs,
                 segment_size,
-                chunk_start=0,
-                chunk_end=q_reshaped_len,
-                real_q_len=q_reshaped_len,
+                chunk_start=chunk_start,
+                chunk_end=chunk_end,
+                real_q_len=real_q_len,
                 scale=scale,
-                is_causal=False,  # Historical blocks are all before current chunk
+                is_causal=True,  # Causal for consistency with GPU-only
             )
-        # block_sums_fine shape: [batch, heads, q_est_blocks, k_est_blocks]
-        # where k_est_blocks = len(available_blocks) * ratio
-
-        # Step 3: Aggregate to CPU block level (hierarchical sum)
-        # This is mathematically equivalent to direct computation but much faster
-        batch_size_bs, num_heads_bs, q_est_blocks, k_est_blocks = block_sums_fine.shape
-        num_cpu_blocks = len(available_blocks)
-
-        with nvtx.range("xattn_estimate_aggregate"):
-            # Reshape: [batch, heads, q_est, k_est] -> [batch, heads, q_est, num_cpu, ratio]
-            block_sums_coarse = block_sums_fine.view(
-                batch_size_bs, num_heads_bs, q_est_blocks, num_cpu_blocks, ratio
-            ).sum(dim=-1)  # [batch, heads, q_est_blocks, num_cpu_blocks]
-
-            # Sum over Q dimension to get total attention from Q chunk to each K block
-            cpu_block_scores = block_sums_coarse.sum(dim=2)  # [batch, heads, num_cpu_blocks]
-
-        # Step 4: Select blocks using score + threshold (replaces mask + majority voting)
-        # This is simpler and more direct than the original mask-based approach
-        with nvtx.range("xattn_estimate_select"):
-            # Average scores across heads (GQA-aware: all heads contribute equally)
-            scores_per_block = cpu_block_scores.mean(dim=(0, 1))  # [num_cpu_blocks]
+        # block_sums shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]
+
+        # ============================================================
+        # Step 6: Use find_blocks_chunked to generate BSA-level mask
+        # ============================================================
+        # Calculate BSA block indices
+        q_bsa_blocks = (padded_q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
+        total_k_bsa_blocks = (padded_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
+        historical_k_bsa_blocks = num_historical_blocks * bsa_per_cpu
+
+        # current_index for find_blocks_chunked: Q's block offset
+        q_start_bsa_block = historical_k_bsa_blocks  # Q starts after historical K
+
+        with nvtx.range("xattn_find_blocks"):
+            mask = find_blocks_chunked(
+                block_sums,
+                current_index=q_start_bsa_block,  # Q's position in BSA blocks
+                threshold=self.threshold,
+                num_to_choose=None,
+                decoding=False,
+                mode="prefill",
+                causal=True,  # Causal for block-level mask
+            )
+        # mask shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]
+
+        # ============================================================
+        # Step 7: Extract mask portions and record density
+        # ============================================================
+        B, H, Q_bsa, K_bsa_total = mask.shape
# Normalize to get attention distribution # Calculate valid Q blocks (excluding padding)
total_score = scores_per_block.sum() valid_q_bsa = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
if total_score > 0: valid_curr_k_bsa = (curr_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
score_ratio = scores_per_block / total_score
else:
# Edge case: all zeros, select all blocks
selected_block_ids = list(available_blocks)
if layer_id == 0 and available_blocks:
self._stats_total_available_blocks += len(available_blocks)
self._stats_total_selected_blocks += len(selected_block_ids)
self._stats_num_chunks += 1
return selected_block_ids
# Sort by score (descending) and select until threshold is reached # 7a: Record historical blocks density
sorted_indices = torch.argsort(score_ratio, descending=True) # IMPORTANT: For historical blocks, apply causal mask to match GPU-only density calculation!
cumsum = 0.0 # Q block i (global position = q_start_bsa_block + i) can see historical K block j
selected_indices = set() # only if j <= q_start_bsa_block + i (causal constraint)
mask_historical = mask[:, :, :valid_q_bsa, :historical_k_bsa_blocks]
for idx in sorted_indices.tolist(): if historical_k_bsa_blocks > 0:
selected_indices.add(idx) # Create causal mask for historical blocks
cumsum += score_ratio[idx].item() # Q_global[i] = q_start_bsa_block + i, K[j] = j
if cumsum >= self.threshold: # Causal: j <= Q_global[i] => j <= q_start_bsa_block + i
break q_global_indices = torch.arange(valid_q_bsa, device=mask.device) + q_start_bsa_block
k_indices = torch.arange(historical_k_bsa_blocks, device=mask.device)
# Q at position q_global_indices[i] can see K at position k_indices[j] if k_indices[j] <= q_global_indices[i]
causal_mask_historical = k_indices.unsqueeze(0) <= q_global_indices.unsqueeze(1) # [valid_q_bsa, historical_k_bsa_blocks]
# Map indices back to block IDs # Count positions within causal mask only
selected_block_ids = [available_blocks[i] for i in sorted(selected_indices)] total_historical_causal = causal_mask_historical.sum().item() * B * H
selected_historical = (mask_historical & causal_mask_historical.unsqueeze(0).unsqueeze(0)).sum().item()
if total_historical_causal > 0:
DensityObserver.record_counts(layer_id, selected_historical, total_historical_causal)
# 7b: Record current chunk density (causal, to align with GPU-only mode)
# Current chunk is the portion after historical blocks
if valid_curr_k_bsa > 0:
# Extract current chunk mask (only valid portion, not padded)
mask_current = mask[:, :, :valid_q_bsa, historical_k_bsa_blocks:historical_k_bsa_blocks + valid_curr_k_bsa]
q_dim = mask_current.shape[2]
k_dim = mask_current.shape[3]
# Create causal mask (lower triangular)
# For current chunk: Q[i] can see K[j] where j <= i (standard causal)
causal_mask = torch.tril(torch.ones(q_dim, k_dim, device=mask.device, dtype=torch.bool))
# Count positions within causal mask only
total_current_causal = causal_mask.sum().item() * B * H
selected_current = (mask_current & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
if total_current_causal > 0:
DensityObserver.record_counts(layer_id, selected_current, total_current_causal)
# Step 7.5: Save historical mask to pre-allocated buffer for compute_chunked_prefill
# Use full Q_bsa (padded) for buffer, not valid_q_bsa
mask_historical_full = mask[:, :, :, :historical_k_bsa_blocks]
if self._prefill_mask_buffer is not None:
# Only save historical portion of mask
self._prefill_mask_buffer[:, :, :Q_bsa, :historical_k_bsa_blocks].copy_(mask_historical_full)
self._current_mask_q_bsa = Q_bsa
self._current_mask_k_bsa = historical_k_bsa_blocks
# ============================================================
# Step 8: Aggregate mask to CPU block level (union of heads)
# ============================================================
# Only aggregate historical blocks (current chunk is always full attention)
num_cpu_blocks = num_historical_blocks
with nvtx.range("xattn_aggregate_mask"):
# Reshape historical mask: [B, H, Q_bsa, historical_k_bsa] -> [B, H, Q_bsa, num_cpu, bsa_per_cpu]
# Use full Q_bsa (not valid_q_bsa) for aggregation
mask_per_cpu = mask_historical_full.view(B, H, Q_bsa, num_cpu_blocks, bsa_per_cpu)
# Union across: bsa_per_cpu, Q_bsa, heads -> [B, num_cpu]
cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1) # [B, num_cpu]
# Get selected indices
selected_indices = cpu_needed[0].nonzero().squeeze(-1).tolist()
if isinstance(selected_indices, int):
selected_indices = [selected_indices]
# Handle empty available_blocks case (first chunk)
if available_blocks:
selected_block_ids = [available_blocks[i] for i in selected_indices]
else:
selected_block_ids = []
# Always include first block (sink) and last block for safety # Always include first block (sink) and last block for safety
if available_blocks and available_blocks[0] not in selected_block_ids: if available_blocks and available_blocks[0] not in selected_block_ids:
@@ -598,6 +758,14 @@ class XAttentionBSAPolicy(SparsePolicy):
if available_blocks and available_blocks[-1] not in selected_block_ids: if available_blocks and available_blocks[-1] not in selected_block_ids:
selected_block_ids.append(available_blocks[-1]) selected_block_ids.append(available_blocks[-1])
# Record communication density (CPU block granularity) - only if there are historical blocks
if available_blocks:
DensityObserver.record_comm_density(
layer_id,
selected_cpu_blocks=len(selected_block_ids),
total_cpu_blocks=len(available_blocks),
)
# Update statistics (only for layer 0 to avoid overcounting) # Update statistics (only for layer 0 to avoid overcounting)
if layer_id == 0 and available_blocks: if layer_id == 0 and available_blocks:
self._stats_total_available_blocks += len(available_blocks) self._stats_total_available_blocks += len(available_blocks)
@@ -610,7 +778,7 @@ class XAttentionBSAPolicy(SparsePolicy):
f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}") f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}")
# Free intermediate tensors to prevent memory leak # Free intermediate tensors to prevent memory leak
del attn_scores, block_sums_fine, block_sums_coarse, cpu_block_scores, scores_per_block del attn_scores, block_sums, mask, mask_historical_full
return selected_block_ids return selected_block_ids
@@ -636,6 +804,10 @@ class XAttentionBSAPolicy(SparsePolicy):
2. Compute attention to current chunk 2. Compute attention to current chunk
3. Merge all results 3. Merge all results
Note: The BSA-level mask is saved in self._prefill_mask_buffer by select_blocks().
Currently we use flash_attn_with_lse for computation (supports LSE merge).
TODO: Optimize to use BSA kernel with the saved mask for per-head sparse attention.
Args: Args:
q: Query tensor [seq_len, num_heads, head_dim] q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer) k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, from prefill buffer)
@@ -666,6 +838,11 @@ class XAttentionBSAPolicy(SparsePolicy):
# Use the pre-selected blocks directly # Use the pre-selected blocks directly
cpu_block_table = selected_blocks cpu_block_table = selected_blocks
# Note: BSA mask is available in self._prefill_mask_buffer (saved by select_blocks)
# Mask shape: [1, num_heads, Q_bsa, K_bsa] where Q_bsa = self._current_mask_q_bsa
# Selected indices: self._selected_cpu_indices, bsa_per_cpu: self._bsa_per_cpu
# TODO: Use this mask with BSA kernel for per-head sparse attention optimization
if cpu_block_table: if cpu_block_table:
with nvtx.range("xattn_compute_historical"): with nvtx.range("xattn_compute_historical"):
load_slots = list(range(offload_engine.num_ring_slots)) load_slots = list(range(offload_engine.num_ring_slots))
View File
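Step 8 above fetches a CPU block whenever any head, any query block, or any BSA sub-block inside it is selected. The reduction can be sketched in plain PyTorch (shapes and names here are illustrative, not the exact nano-vllm code):

```python
import torch

# Illustrative shapes: 1 batch, 2 heads, 4 Q BSA-blocks, 2 CPU blocks x 4 BSA-blocks each
B, H, Q_bsa, num_cpu, bsa_per_cpu = 1, 2, 4, 2, 4
mask = torch.zeros(B, H, Q_bsa, num_cpu * bsa_per_cpu, dtype=torch.bool)
mask[0, 1, 2, 5] = True  # one head needs a single BSA block inside CPU block 1

# Union over BSA-blocks-within-a-CPU-block, then Q blocks, then heads:
# a CPU block must be transferred if ANY of its BSA blocks is selected anywhere.
cpu_needed = (
    mask.view(B, H, Q_bsa, num_cpu, bsa_per_cpu)
    .any(dim=-1)   # [B, H, Q_bsa, num_cpu]
    .any(dim=2)    # [B, H, num_cpu]
    .any(dim=1)    # [B, num_cpu]
)
selected = cpu_needed[0].nonzero().squeeze(-1).tolist()
print(selected)  # [1]
```

This union semantics is also why the communication density (CPU granularity) can only be higher than the compute density (BSA granularity): one hot BSA block forces the whole 32-block CPU group across.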
@@ -221,20 +221,19 @@ class Attention(nn.Module):
         cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
         # Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill)
-        selected_blocks = []
-        if cpu_block_table:
-            num_chunks = current_chunk_idx + 1
-            policy_ctx = PolicyContext(
-                query_chunk_idx=current_chunk_idx,
-                num_query_chunks=num_chunks,
-                layer_id=self.layer_id,
-                query=q,  # Pass query for sparse policies that need it
-                is_prefill=True,
-                block_size=kvcache_manager.block_size,
-                total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
-            )
-            selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
-            logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
+        # Always call select_blocks even for first chunk (cpu_block_table may be empty)
+        num_chunks = current_chunk_idx + 1
+        policy_ctx = PolicyContext(
+            query_chunk_idx=current_chunk_idx,
+            num_query_chunks=num_chunks,
+            layer_id=self.layer_id,
+            query=q,  # Pass query for sparse policies that need it
+            is_prefill=True,
+            block_size=kvcache_manager.block_size,
+            total_kv_len=len(cpu_block_table) * kvcache_manager.block_size if cpu_block_table else 0,
+        )
+        selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
+        logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
         # [DEBUG] Verify execution path
         logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
@@ -320,7 +319,7 @@ class Attention(nn.Module):
             block_size=kvcache_manager.block_size,
             total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
         )
-        selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
+        selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
         logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")
         # [DEBUG] Verify execution path
View File
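The attention-layer change above drops the `if cpu_block_table:` guard so the policy is invoked even on the first chunk, when no historical blocks exist yet. A minimal sketch of that contract (the `PolicyContext` fields mirror the diff; `DemoPolicy` is a hypothetical stand-in, not a nano-vllm class):

```python
from dataclasses import dataclass
from typing import Any, List

@dataclass
class PolicyContext:
    # Fields as constructed in the diff above
    query_chunk_idx: int
    num_query_chunks: int
    layer_id: int
    query: Any = None
    is_prefill: bool = True
    block_size: int = 4096
    total_kv_len: int = 0

class DemoPolicy:
    """Hypothetical policy illustrating why calling with an empty table is safe."""
    def select_blocks(self, cpu_block_table: List[int], engine, ctx: PolicyContext,
                      q=None, k=None) -> List[int]:
        if not cpu_block_table:
            return []  # first chunk: nothing historical to select
        return list(cpu_block_table)  # placeholder: select everything

ctx = PolicyContext(query_chunk_idx=0, num_query_chunks=1, layer_id=0)
print(DemoPolicy().select_blocks([], None, ctx))  # []
```

Always calling the policy lets implementations like `XAttentionBSAPolicy` do first-chunk setup (e.g., mask buffers, density recording) instead of being skipped entirely.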
@@ -0,0 +1,308 @@
"""
DensityObserver - Sparse Attention Density 统计 Observer。
统计两种 density:
1. Compute Density (计算密度): 基于 BSA block size (128)
- density = selected_bsa_blocks / total_causal_bsa_blocks
- GPU-only 和 Offload 模式应该一致
2. Communication Density (通信密度): 基于 CPU block size (如 4096)
- comm_density = selected_cpu_blocks / total_cpu_blocks
- 仅用于 Offload 模式,由于粒度更粗,必然 >= compute density
统计位置:
- GPU-only: xattn_bsa.py compute_prefill() - 只记录 compute density
- Offload: xattn_bsa.py select_blocks() - 记录两种 density
对于 Offload 模式的 Density 计算:
- 不是简单的 avg 或 min
- 而是 sum(selected) / sum(total),正确处理不同 chunk 大小的权重
"""
from typing import List, Dict, Optional, Tuple
import torch
from nanovllm.utils.observer import Observer
class DensityObserver(Observer):
"""
Sparse Attention Density Observer。
记录每层的 density用于验证 GPU-only 和 Offload 模式的一致性。
使用方式:
DensityObserver.enable()
DensityObserver.complete_reset()
# ... run inference ...
DensityObserver.record(layer_id, mask, causal=True)
# 或者使用累积模式 (offload):
DensityObserver.record_counts(layer_id, selected, total)
# ...
DensityObserver.print_summary()
"""
_enabled: bool = False # 默认禁用
# 每层的 compute density 记录 (BSA block 粒度)
# key: layer_id, value: list of density values (每次 prefill chunk 一个)
_layer_densities: Dict[int, List[float]] = {}
# 每层的 communication density 记录 (CPU block 粒度,仅 offload 模式)
_layer_comm_densities: Dict[int, List[float]] = {}
# 累积模式: 记录 selected/total counts (用于 offload 模式)
# 这样可以在所有 chunks 完成后正确计算 density = sum(selected) / sum(total)
_layer_selected_counts: Dict[int, List[int]] = {}
_layer_total_counts: Dict[int, List[int]] = {}
# Mask shape 记录 (用于调试)
_last_q_blocks: int = 0
_last_k_blocks: int = 0
# 模式标记
_mode: str = "unknown" # "gpu_only" or "offload"
@classmethod
def set_mode(cls, mode: str) -> None:
"""设置当前模式 (gpu_only / offload)"""
cls._mode = mode
@classmethod
def record(
cls,
layer_id: int,
mask: torch.Tensor,
causal: bool = True,
) -> float:
"""
记录一层的 density (适用于 GPU-only 模式)。
Args:
layer_id: 层 ID
mask: [batch, heads, q_blocks, k_blocks] boolean tensor
causal: 是否考虑 causal mask (只计算下三角)
Returns:
density 值
"""
if not cls._enabled:
return 0.0
density = cls._compute_density(mask, causal)
# Record
if layer_id not in cls._layer_densities:
cls._layer_densities[layer_id] = []
cls._layer_densities[layer_id].append(density)
# Record the mask shape
cls._last_q_blocks = mask.shape[2]
cls._last_k_blocks = mask.shape[3]
return density
@classmethod
def record_counts(
cls,
layer_id: int,
selected_blocks: int,
total_blocks: int,
) -> None:
"""
记录一层的 selected/total block counts (适用于 offload 累积模式)。
使用累积计数而不是直接计算 density这样在所有 chunks 处理完后可以正确计算:
overall_density = sum(selected) / sum(total)
这比 avg(density) 更准确,因为不同 chunk 的 Q 和 K 长度不同。
Args:
layer_id: 层 ID
selected_blocks: 这个 chunk 选中的 blocks 数量
total_blocks: 这个 chunk 的 total possible blocks 数量
"""
if not cls._enabled:
return
# Initialize lists
if layer_id not in cls._layer_selected_counts:
cls._layer_selected_counts[layer_id] = []
if layer_id not in cls._layer_total_counts:
cls._layer_total_counts[layer_id] = []
# Accumulate counts
cls._layer_selected_counts[layer_id].append(selected_blocks)
cls._layer_total_counts[layer_id].append(total_blocks)
@classmethod
def record_comm_density(
cls,
layer_id: int,
selected_cpu_blocks: int,
total_cpu_blocks: int,
) -> float:
"""
记录一层的 communication density (CPU block 粒度)。
Args:
layer_id: 层 ID
selected_cpu_blocks: 选中的 CPU blocks 数量
total_cpu_blocks: 总 CPU blocks 数量
Returns:
communication density 值
"""
if not cls._enabled:
return 0.0
if total_cpu_blocks == 0:
return 1.0
comm_density = selected_cpu_blocks / total_cpu_blocks
# Record
if layer_id not in cls._layer_comm_densities:
cls._layer_comm_densities[layer_id] = []
cls._layer_comm_densities[layer_id].append(comm_density)
return comm_density
@classmethod
def _compute_density(cls, mask: torch.Tensor, causal: bool) -> float:
"""计算 mask 的 density"""
batch, heads, q_blocks, k_blocks = mask.shape
if causal:
# Count the lower-triangular region only
causal_mask = torch.tril(
torch.ones(q_blocks, k_blocks, device=mask.device, dtype=torch.bool)
)
total_blocks = causal_mask.sum().item() * batch * heads
selected_blocks = (mask & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
else:
total_blocks = mask.numel()
selected_blocks = mask.sum().item()
if total_blocks == 0:
return 1.0
return selected_blocks / total_blocks
@classmethod
def complete_reset(cls) -> None:
"""重置所有统计"""
cls._layer_densities = {}
cls._layer_comm_densities = {}
cls._layer_selected_counts = {}
cls._layer_total_counts = {}
cls._last_q_blocks = 0
cls._last_k_blocks = 0
cls._mode = "unknown"
@classmethod
def get_per_layer_density(cls) -> Dict[int, float]:
"""
获取每层的 density。
对于累积模式 (offload): density = sum(selected) / sum(total)
对于直接记录模式 (gpu_only): density = avg(density_values)
"""
result = {}
# Prefer cumulative mode (offload)
if cls._layer_selected_counts:
for layer_id in cls._layer_selected_counts:
selected_list = cls._layer_selected_counts.get(layer_id, [])
total_list = cls._layer_total_counts.get(layer_id, [])
total_selected = sum(selected_list)
total_total = sum(total_list)
if total_total > 0:
result[layer_id] = total_selected / total_total
else:
# Direct-recording mode (gpu_only)
for layer_id, densities in cls._layer_densities.items():
if densities:
result[layer_id] = sum(densities) / len(densities)
return result
@classmethod
def get_overall_density(cls) -> float:
"""
获取所有层的总体 compute density。
对于累积模式 (offload): density = sum(all_selected) / sum(all_total)
对于直接记录模式 (gpu_only): density = avg(all_density_values)
注意: 总体 density 不是简单的 avg(per_layer_density)
而是 sum(all_selected) / sum(all_total),这样可以正确处理权重。
"""
# Prefer cumulative mode (offload)
if cls._layer_selected_counts:
total_selected = 0
total_total = 0
for layer_id in cls._layer_selected_counts:
total_selected += sum(cls._layer_selected_counts[layer_id])
total_total += sum(cls._layer_total_counts.get(layer_id, []))
if total_total > 0:
return total_selected / total_total
return 0.0
# Direct-recording mode (gpu_only)
all_densities = []
for densities in cls._layer_densities.values():
all_densities.extend(densities)
if not all_densities:
return 0.0
return sum(all_densities) / len(all_densities)
@classmethod
def get_overall_comm_density(cls) -> float:
"""获取所有层的平均 communication density"""
all_densities = []
for densities in cls._layer_comm_densities.values():
all_densities.extend(densities)
if not all_densities:
return 0.0
return sum(all_densities) / len(all_densities)
@classmethod
def get_summary(cls) -> dict:
"""返回统计摘要"""
per_layer = cls.get_per_layer_density()
return {
"mode": cls._mode,
"overall_density": cls.get_overall_density(),
"per_layer_density": per_layer,
"num_layers": len(per_layer),
"last_mask_shape": {
"q_blocks": cls._last_q_blocks,
"k_blocks": cls._last_k_blocks,
},
}
@classmethod
def get_min_density(cls) -> Tuple[int, float]:
"""获取最低 density 的层和值"""
per_layer = cls.get_per_layer_density()
if not per_layer:
return -1, 0.0
min_layer = min(per_layer, key=per_layer.get)
return min_layer, per_layer[min_layer]
@classmethod
def print_summary(cls) -> None:
"""打印人类可读的摘要"""
per_layer = cls.get_per_layer_density()
overall = cls.get_overall_density()
min_layer, min_density = cls.get_min_density()
overall_comm = cls.get_overall_comm_density()
print(f"[DensityObserver] Mode: {cls._mode}")
print(f" Compute density: {overall:.4f} (min: {min_density:.4f} @ layer {min_layer})")
if overall_comm > 0:
print(f" Comm density: {overall_comm:.4f}")
print(f" Num layers: {len(per_layer)}")
# Print layer 0's density for comparison
if 0 in per_layer:
print(f" Layer 0 density: {per_layer[0]:.6f}")
View File
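For intuition, the causal counting in `_compute_density` above can be checked against a toy mask in plain PyTorch (shapes are illustrative):

```python
import torch

# [batch=1, heads=1, q_blocks=4, k_blocks=4]: every block selected...
mask = torch.ones(1, 1, 4, 4, dtype=torch.bool)
mask[0, 0, 3, 0] = False  # ...except one position inside the causal region

causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))
total = causal.sum().item() * 1 * 1              # 10 causal block positions
selected = (mask & causal).sum().item()          # 9 of them are selected
density = selected / total
print(density)  # 0.9
```

Positions above the diagonal never count toward the denominator, so a fully dense causal mask reports density 1.0 rather than ~0.5.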
@@ -41,6 +41,7 @@ from pathlib import Path
 from typing import List, Dict, Tuple, Optional
 from nanovllm import LLM, SamplingParams
+from nanovllm.utils.density_observer import DensityObserver
 # ============================================================
@@ -381,6 +382,13 @@ def run_ruler_benchmark(
     print(f"Fresh LLM mode: {fresh_llm}")
     print(f"{'='*60}")
+    # Enable DensityObserver for XAttention BSA
+    if sparse_policy and sparse_policy.upper() == "XATTN_BSA":
+        DensityObserver.enable()
+        DensityObserver.complete_reset()
+        if not json_output:
+            print("[DensityObserver] Enabled for XAttention BSA")
     # LLM initialization kwargs
     llm_kwargs = {
         "max_model_len": max_model_len,
@@ -471,6 +479,14 @@ def run_ruler_benchmark(
     print(f"{'-'*54}")
     print(f"{'TOTAL':<20} {total_correct}/{total_samples:<7} {overall_accuracy*100:>6.1f}% {avg_score:.3f}")
     print(f"\nTime: {total_time:.1f}s")
+    # Print DensityObserver summary if enabled
+    if sparse_policy and sparse_policy.upper() == "XATTN_BSA" and DensityObserver.is_enabled():
+        print(f"\n{'='*60}")
+        print("Density Statistics (XAttention BSA)")
+        print(f"{'='*60}")
+        DensityObserver.print_summary()
+        print(f"{'='*60}\n")
     results = {
View File
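The summary printed by `test_ruler.py` uses the cumulative definition from `DensityObserver`. A quick numeric comparison shows why `sum(selected) / sum(total)` differs from averaging per-chunk ratios (the counts below are made up):

```python
# Two chunks: a small early chunk and a larger later one.
selected = [1, 30]
total = [4, 60]

avg_of_ratios = sum(s / t for s, t in zip(selected, total)) / len(total)
cumulative = sum(selected) / sum(total)

print(round(avg_of_ratios, 4))  # 0.375
print(round(cumulative, 4))     # 0.4844
```

Averaging gives each chunk equal weight regardless of size; the cumulative form weights each chunk by its block count, which matches what the hardware actually computed or transferred.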
@@ -41,9 +41,9 @@ K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()
 for i in range(q_len):
     if i % 2 == 0:
-        Q[0, 0, i, :] = 1
+        Q[0, 0, i, :] = 1 * (i // stride + 1)
     else:
-        Q[0, 0, i, :] = 2
+        Q[0, 0, i, :] = 2 * (i // stride + 1)
 for i in range(kv_len):
     if i % 2 == 0:
@@ -74,8 +74,11 @@ for k_chunk_idx in range(num_k_chunks):
         Q, K_chunk, stride,
         chunk_start=0,
         chunk_end=q_reshaped_len,
-        is_causal=False
+        is_causal=True
     )
+    __import__('pdb').set_trace()
     attn_scores_list.append(attn_chunk)
 # Concatenate results from all K chunks
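For intuition about what the kernel test is exercising, here is a dense reference of causal softmax followed by per-K-block summation. This is a plain-PyTorch sketch of the semantics, not the Triton `softmax_fuse_block_sum` kernel itself:

```python
import torch

def block_sum_softmax_ref(scores: torch.Tensor, block: int, causal: bool = True):
    """Softmax each query row (optionally causal), then sum probabilities per K block."""
    q_len, k_len = scores.shape
    if causal:
        # Mask out future keys before the softmax
        scores = scores.masked_fill(
            torch.triu(torch.ones(q_len, k_len, dtype=torch.bool), diagonal=1),
            float("-inf"),
        )
    probs = torch.softmax(scores, dim=-1)                    # each row sums to 1
    return probs.view(q_len, k_len // block, block).sum(-1)  # [q_len, k_blocks]

out = block_sum_softmax_ref(torch.zeros(4, 4), block=2)
print(out.sum().item())  # 4.0: each of the 4 rows contributes probability mass 1
```

With uniform scores the causal mask concentrates mass on earlier K blocks (later keys are invisible to early queries), which is exactly the asymmetry the switch from `is_causal=False` to `is_causal=True` introduces in the test above.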