From f6ac4ccddea8034a41b947f52f95c3acf1ed5909 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Fri, 30 Jan 2026 16:26:56 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20add=20DensityObserver=20for?= =?UTF-8?q?=20XAttention=20sparse=20attention=20density=20tracking?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add DensityObserver class to track per-layer density statistics - Integrate DensityObserver into compute_prefill for GPU-only mode - Fix stride parameter not being passed to xattn_estimate - Add density statistics output to test_ruler.py for XATTN_BSA - Add comprehensive density benchmark documentation Key changes: - nanovllm/utils/density_observer.py: New Observer for density tracking - xattn_bsa.py: Add stride param to xattn_estimate, integrate DensityObserver - test_ruler.py: Enable DensityObserver and print summary for XATTN_BSA - docs/xattn_density_benchmark.md: Benchmark results for 4K-32K contexts Co-Authored-By: Claude Opus 4.5 --- CLAUDE.md | 1 + docs/xattn_density_benchmark.md | 195 +++++++++++++++++++++++++++ nanovllm/kvcache/sparse/xattn_bsa.py | 15 ++- nanovllm/utils/density_observer.py | 167 +++++++++++++++++++++++ tests/test_ruler.py | 16 +++ 5 files changed, 387 insertions(+), 7 deletions(-) create mode 100644 docs/xattn_density_benchmark.md create mode 100644 nanovllm/utils/density_observer.py diff --git a/CLAUDE.md b/CLAUDE.md index 9aa3284..e643971 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -18,6 +18,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L | [`docs/xattn_kernels_guide.md`](docs/xattn_kernels_guide.md) | XAttention Triton kernels: flat_group_gemm (反对角线求和)、softmax_fuse_block_sum (block 聚合) | | [`docs/xattn_chunked_prefill.md`](docs/xattn_chunked_prefill.md) | XAttention chunked prefill: API、使用方式、一致性要求 | | [`docs/xattn_bsa_policy_design.md`](docs/xattn_bsa_policy_design.md) | XAttention BSA Policy: 算法设计、性能基准(128K)、内存管理、density 统计 | +| [`docs/xattn_density_benchmark.md`](docs/xattn_density_benchmark.md) | 📊 XAttention Density Benchmark: 4K-32K context、stride 参数、per-layer density 分析 | | [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) 接口文档: 函数签名、使用示例、约束条件 | | [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling | | [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) | diff --git a/docs/xattn_density_benchmark.md b/docs/xattn_density_benchmark.md new file mode 100644 index 0000000..a6731da --- /dev/null +++ b/docs/xattn_density_benchmark.md @@ -0,0 +1,195 @@ +# XAttention Density Benchmark + +GPU-only 模式下 XAttention Block Sparse Attention 的 density 测试结果。 + +## 测试配置 + +| 参数 | 值 | 说明 | +|------|-----|------| +| Model | Llama-3.1-8B-Instruct | 32 layers, 32 heads, 8 KV heads | +| Block Size | 128 tokens | BSA kernel 固定要求 | +| Threshold | 0.9 / 0.95 | 累积注意力阈值 | +| Stride | 4 / 8 / 16 | Q/K 下采样步长 | +| Dataset | RULER niah_single_1 | Sample 0 | +| Mode | GPU-only | 无 CPU offload | + +## Density 定义 + +```python +# Density = selected_blocks / total_causal_blocks +# 在 causal attention 下,只计算下三角区域的 blocks +# Overall density = 所有层的平均值 + +def compute_density(mask, causal=True): + """ + mask: [batch, heads, q_blocks, k_blocks] boolean tensor + """ + if causal: + causal_mask = torch.tril(torch.ones(q_blocks, k_blocks)) + total = causal_mask.sum() * batch * heads + selected = (mask & causal_mask).sum() + return selected / total +``` + +## 测试结果 + +### threshold=0.9 + +#### Overall Density (平均) + +| Context | stride=4 | stride=8 | stride=16 | +|---------|----------|----------|-----------| +| **4K** | 0.5220 (52.2%) | 0.5292 (52.9%) | 0.5430 (54.3%) | +| **8K** | 0.5152 (51.5%) | 0.5252 (52.5%) | 0.5396 (54.0%) | +| **16K** | 0.4682 (46.8%) | 0.4775 (47.8%) | 0.4888 (48.9%) | +| **32K** | 0.3700 (37.0%) | 0.4012 (40.1%) | 0.4196 (42.0%) | + +#### Min Density (per layer) + +| Context | stride=4 | stride=8 | stride=16 | +|---------|----------|----------|-----------| +| **4K** | 0.2805 (Layer 3) | 0.3132 (Layer 3) | 0.3376 (Layer 5) | +| **8K** | 0.2886 (Layer 5) | 0.2725 (Layer 5) | 0.2995 (Layer 5) | +| **16K** | 0.2247 (Layer 5) | 0.2349 (Layer 5) | 0.2451 (Layer 5) | +| **32K** | 0.1799 (Layer 5) | 0.1846 (Layer 5) | 0.1964 (Layer 5) | + +### threshold=0.95 + +#### Overall Density (平均) + +| Context | stride=4 | stride=8 | stride=16 | +|---------|----------|----------|-----------| +| **4K** | 0.6561 (65.6%) | 0.6699 (67.0%) | 0.6815 (68.2%) | +| **8K** | 0.6462 (64.6%) | 0.6584 (65.8%) | 0.6732 (67.3%) | +| **16K** | 0.6004 (60.0%) | 0.6114 (61.1%) | 0.6193 (61.9%) | +| **32K** | 0.4894 (48.9%) | 0.5203 (52.0%) | 0.5385 (53.9%) | + +#### Min Density (per layer) + +| Context | stride=4 | stride=8 | stride=16 | +|---------|----------|----------|-----------| +| **4K** | 0.3972 (Layer 3) | 0.4348 (Layer 5) | 0.4517 (Layer 4) | +| **8K** | 0.4004 (Layer 5) | 0.3906 (Layer 5) | 0.4239 (Layer 5) | +| **16K** | 0.3331 (Layer 5) | 0.3453 (Layer 5) | 0.3589 (Layer 5) | +| **32K** | 0.2656 (Layer 5) | 0.2784 (Layer 5) | 0.2917 (Layer 5) | + +### threshold 对比 (stride=8) + +| Context | threshold=0.9 | threshold=0.95 | 差异 | +|---------|---------------|----------------|------| +| **4K** | 0.5292 (52.9%) | 0.6699 (67.0%) | -14.1% | +| **8K** | 0.5252 (52.5%) | 0.6584 (65.8%) | -13.3% | +| **16K** | 0.4775 (47.8%) | 0.6114 (61.1%) | -13.4% | +| **32K** | 0.4012 (40.1%) | 0.5203 (52.0%) | -11.9% | + +## 关键发现 + +### 1. Context Length 影响最大 + +Density 随 context length 显著下降(threshold=0.9, stride=8): +- 4K: 52.9% density +- 8K: 52.5% density +- 16K: 47.8% density +- 32K: 40.1% density + +**结论**: 长序列有更多稀疏化机会,XAttention 的优势在长序列上更明显。 + +### 2. Threshold 影响显著 + +threshold=0.9 比 0.95 的 density 低约 12-14%: +- 0.9 更激进,选择更少的 blocks +- 0.95 更保守,保留更多 blocks +- 两者准确性都不受影响(RULER NIAH 全部 PASS) + +### 3. Stride 影响较小 + +同一 context 下,不同 stride 的 density 差异约 2-5%: +- stride 越大 → density 略高(采样越粗糙,选择更保守) +- stride=4 最激进,stride=16 最保守 + +### 4. Min Density 集中在中间层 + +- 大多数情况下 min density 出现在 Layer 5 +- 中间层的稀疏性最高,首尾层相对密集 +- 这符合 Transformer 注意力模式的一般规律 + +### 5. 最佳稀疏化配置 + +32K + stride=4 + threshold=0.9 达到最低 density: +- Overall: **37.0%** (节省 63% 计算) +- Min: **18.0%** (Layer 5) + +### 6. 准确性稳定 + +所有配置下 RULER NIAH 测试都 PASS (score=1.0),说明: +- threshold=0.9 和 0.95 都足够保守,不损失准确性 +- 不同 stride 不影响最终结果 + +## 推荐配置 + +| 场景 | threshold | stride | 说明 | +|------|-----------|--------|------| +| 精度优先 | 0.95 | 8 | 保守配置,density ~52-67% | +| 平衡 | 0.9 | 8 | 默认配置,density ~40-53% | +| 性能优先 | 0.9 | 4 | 激进配置,density ~37-52% | + +## 测试命令 + +```bash +# 基本测试 +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \ + python tests/test_ruler.py \ + --data-dir tests/data/ruler_32k \ + --datasets niah_single_1 \ + --sample-indices 0 \ + --max-model-len 33792 \ + --sparse-policy XATTN_BSA \ + --sparse-threshold 0.9 \ + --sparse-stride 8 \ + --gpu-utilization 0.85 + +# 参数说明 +# --sparse-policy XATTN_BSA 启用 XAttention Block Sparse Attention +# --sparse-threshold 0.9 累积注意力阈值 (0.9-0.99) +# --sparse-stride 8 Q/K 下采样步长 (4/8/16) +``` + +## DensityObserver 使用 + +```python +from nanovllm.utils.density_observer import DensityObserver + +# 启用并重置 +DensityObserver.enable() +DensityObserver.complete_reset() + +# ... 运行 inference (compute_prefill 自动记录) ... + +# 获取结果 +summary = DensityObserver.get_summary() +# { +# "mode": "gpu_only", +# "overall_density": 0.40, # 所有层的平均值 +# "per_layer_density": {0: 0.55, 1: 0.45, ...}, +# "num_layers": 32 +# } + +# 获取最低 density +min_layer, min_density = DensityObserver.get_min_density() + +# 打印摘要 +DensityObserver.print_summary() +# [DensityObserver] Mode: gpu_only +# Overall density: 0.4012 +# Min density: 0.1846 (layer 5) +# Num layers: 32 +``` + +## 相关文件 + +| 文件 | 说明 | +|------|------| +| `nanovllm/kvcache/sparse/xattn_bsa.py` | XAttention BSA Policy 实现 | +| `nanovllm/utils/density_observer.py` | Density 统计 Observer | +| `nanovllm/ops/xattn.py` | xattn_estimate 核心算法 | +| `tests/test_ruler.py` | RULER benchmark 测试脚本 | diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index bf3978a..04098c0 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -17,6 +17,7 @@ import torch.cuda.nvtx as nvtx from typing import List, Tuple, TYPE_CHECKING from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext +from nanovllm.utils.density_observer import DensityObserver if TYPE_CHECKING: from nanovllm.kvcache.offload_engine import OffloadEngine @@ -258,6 +259,10 @@ class XAttentionBSAPolicy(SparsePolicy): from nanovllm.ops.xattn import xattn_estimate + # Set DensityObserver mode on first layer + if layer_id == 0: + DensityObserver.set_mode("gpu_only") + # Get dimensions total_q, num_heads, head_dim = q.shape total_kv, num_kv_heads, _ = k.shape @@ -315,6 +320,7 @@ class XAttentionBSAPolicy(SparsePolicy): Q, K_exp, chunk_size=self.chunk_size, block_size=self.BSA_BLOCK_SIZE, + stride=self.stride, threshold=self.threshold, use_triton=self.use_triton, causal=True, @@ -360,13 +366,8 @@ class XAttentionBSAPolicy(SparsePolicy): is_causal=True, ) - # Update statistics (layer 0 only to avoid overcounting) - if layer_id == 0: - selected_blocks = mask_trimmed.sum().item() - total_blocks = q_block_num * k_block_num * num_heads - density = selected_blocks / total_blocks if total_blocks > 0 else 1.0 - logger.debug(f"[XAttn GPU-only] layer={layer_id}, q_blocks={q_block_num}, " - f"k_blocks={k_block_num}, density={density:.1%}") + # Record density for all layers via DensityObserver + DensityObserver.record(layer_id, mask_trimmed, causal=True) return output diff --git a/nanovllm/utils/density_observer.py b/nanovllm/utils/density_observer.py new file mode 100644 index 0000000..3537980 --- /dev/null +++ b/nanovllm/utils/density_observer.py @@ -0,0 +1,167 @@ +""" +DensityObserver - Sparse Attention Density 统计 Observer。 + +统计每层的 sparse attention density: +- density = selected_blocks / total_causal_blocks +- 在 causal attention 下,只计算下三角区域 + +统计位置: +- GPU-only: xattn_bsa.py compute_prefill() +- Offload: xattn_bsa.py select_blocks() +""" + +from typing import List, Dict, Optional, Tuple +import torch +from nanovllm.utils.observer import Observer + + +class DensityObserver(Observer): + """ + Sparse Attention Density Observer。 + + 记录每层的 density,用于验证 GPU-only 和 Offload 模式的一致性。 + + 使用方式: + DensityObserver.enable() + DensityObserver.complete_reset() + # ... run inference ... + DensityObserver.record(layer_id, mask, causal=True) + # ... + DensityObserver.print_summary() + """ + + _enabled: bool = False # 默认禁用 + + # 每层的 density 记录 + # key: layer_id, value: list of density values (每次 prefill chunk 一个) + _layer_densities: Dict[int, List[float]] = {} + + # Mask shape 记录 (用于调试) + _last_q_blocks: int = 0 + _last_k_blocks: int = 0 + + # 模式标记 + _mode: str = "unknown" # "gpu_only" or "offload" + + @classmethod + def set_mode(cls, mode: str) -> None: + """设置当前模式 (gpu_only / offload)""" + cls._mode = mode + + @classmethod + def record( + cls, + layer_id: int, + mask: torch.Tensor, + causal: bool = True, + ) -> float: + """ + 记录一层的 density。 + + Args: + layer_id: 层 ID + mask: [batch, heads, q_blocks, k_blocks] boolean tensor + causal: 是否考虑 causal mask (只计算下三角) + + Returns: + density 值 + """ + if not cls._enabled: + return 0.0 + + density = cls._compute_density(mask, causal) + + # 记录 + if layer_id not in cls._layer_densities: + cls._layer_densities[layer_id] = [] + cls._layer_densities[layer_id].append(density) + + # 记录 mask shape + cls._last_q_blocks = mask.shape[2] + cls._last_k_blocks = mask.shape[3] + + return density + + @classmethod + def _compute_density(cls, mask: torch.Tensor, causal: bool) -> float: + """计算 mask 的 density""" + batch, heads, q_blocks, k_blocks = mask.shape + + if causal: + # 只计算下三角区域 + causal_mask = torch.tril( + torch.ones(q_blocks, k_blocks, device=mask.device, dtype=torch.bool) + ) + total_blocks = causal_mask.sum().item() * batch * heads + selected_blocks = (mask & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() + else: + total_blocks = mask.numel() + selected_blocks = mask.sum().item() + + if total_blocks == 0: + return 1.0 + + return selected_blocks / total_blocks + + @classmethod + def complete_reset(cls) -> None: + """重置所有统计""" + cls._layer_densities = {} + cls._last_q_blocks = 0 + cls._last_k_blocks = 0 + cls._mode = "unknown" + + @classmethod + def get_per_layer_density(cls) -> Dict[int, float]: + """获取每层的平均 density""" + result = {} + for layer_id, densities in cls._layer_densities.items(): + if densities: + result[layer_id] = sum(densities) / len(densities) + return result + + @classmethod + def get_overall_density(cls) -> float: + """获取所有层的平均 density""" + all_densities = [] + for densities in cls._layer_densities.values(): + all_densities.extend(densities) + if not all_densities: + return 0.0 + return sum(all_densities) / len(all_densities) + + @classmethod + def get_summary(cls) -> dict: + """返回统计摘要""" + per_layer = cls.get_per_layer_density() + return { + "mode": cls._mode, + "overall_density": cls.get_overall_density(), + "per_layer_density": per_layer, + "num_layers": len(per_layer), + "last_mask_shape": { + "q_blocks": cls._last_q_blocks, + "k_blocks": cls._last_k_blocks, + }, + } + + @classmethod + def get_min_density(cls) -> Tuple[int, float]: + """获取最低 density 的层和值""" + per_layer = cls.get_per_layer_density() + if not per_layer: + return -1, 0.0 + min_layer = min(per_layer, key=per_layer.get) + return min_layer, per_layer[min_layer] + + @classmethod + def print_summary(cls) -> None: + """打印人类可读的摘要""" + per_layer = cls.get_per_layer_density() + overall = cls.get_overall_density() + min_layer, min_density = cls.get_min_density() + + print(f"[DensityObserver] Mode: {cls._mode}") + print(f" Overall density: {overall:.4f}") + print(f" Min density: {min_density:.4f} (layer {min_layer})") + print(f" Num layers: {len(per_layer)}") diff --git a/tests/test_ruler.py b/tests/test_ruler.py index 1451f6f..df95ff8 100644 --- a/tests/test_ruler.py +++ b/tests/test_ruler.py @@ -41,6 +41,7 @@ from pathlib import Path from typing import List, Dict, Tuple, Optional from nanovllm import LLM, SamplingParams +from nanovllm.utils.density_observer import DensityObserver # ============================================================ @@ -381,6 +382,13 @@ def run_ruler_benchmark( print(f"Fresh LLM mode: {fresh_llm}") print(f"{'='*60}") + # Enable DensityObserver for XAttention BSA + if sparse_policy and sparse_policy.upper() == "XATTN_BSA": + DensityObserver.enable() + DensityObserver.complete_reset() + if not json_output: + print("[DensityObserver] Enabled for XAttention BSA") + # LLM initialization kwargs llm_kwargs = { "max_model_len": max_model_len, @@ -471,6 +479,14 @@ def run_ruler_benchmark( print(f"{'-'*54}") print(f"{'TOTAL':<20} {total_correct}/{total_samples:<7} {overall_accuracy*100:>6.1f}% {avg_score:.3f}") print(f"\nTime: {total_time:.1f}s") + + # Print DensityObserver summary if enabled + if sparse_policy and sparse_policy.upper() == "XATTN_BSA" and DensityObserver.is_enabled(): + print(f"\n{'='*60}") + print("Density Statistics (XAttention BSA)") + print(f"{'='*60}") + DensityObserver.print_summary() + print(f"{'='*60}\n") results = {