✨ feat: add DensityObserver for XAttention sparse attention density tracking
- Add DensityObserver class to track per-layer density statistics - Integrate DensityObserver into compute_prefill for GPU-only mode - Fix stride parameter not being passed to xattn_estimate - Add density statistics output to test_ruler.py for XATTN_BSA - Add comprehensive density benchmark documentation Key changes: - nanovllm/utils/density_observer.py: New Observer for density tracking - xattn_bsa.py: Add stride param to xattn_estimate, integrate DensityObserver - test_ruler.py: Enable DensityObserver and print summary for XATTN_BSA - docs/xattn_density_benchmark.md: Benchmark results for 4K-32K contexts Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
167
nanovllm/utils/density_observer.py
Normal file
167
nanovllm/utils/density_observer.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
DensityObserver - Sparse Attention Density 统计 Observer。
|
||||
|
||||
统计每层的 sparse attention density:
|
||||
- density = selected_blocks / total_causal_blocks
|
||||
- 在 causal attention 下,只计算下三角区域
|
||||
|
||||
统计位置:
|
||||
- GPU-only: xattn_bsa.py compute_prefill()
|
||||
- Offload: xattn_bsa.py select_blocks()
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Optional, Tuple
|
||||
import torch
|
||||
from nanovllm.utils.observer import Observer
|
||||
|
||||
|
||||
class DensityObserver(Observer):
|
||||
"""
|
||||
Sparse Attention Density Observer。
|
||||
|
||||
记录每层的 density,用于验证 GPU-only 和 Offload 模式的一致性。
|
||||
|
||||
使用方式:
|
||||
DensityObserver.enable()
|
||||
DensityObserver.complete_reset()
|
||||
# ... run inference ...
|
||||
DensityObserver.record(layer_id, mask, causal=True)
|
||||
# ...
|
||||
DensityObserver.print_summary()
|
||||
"""
|
||||
|
||||
_enabled: bool = False # 默认禁用
|
||||
|
||||
# 每层的 density 记录
|
||||
# key: layer_id, value: list of density values (每次 prefill chunk 一个)
|
||||
_layer_densities: Dict[int, List[float]] = {}
|
||||
|
||||
# Mask shape 记录 (用于调试)
|
||||
_last_q_blocks: int = 0
|
||||
_last_k_blocks: int = 0
|
||||
|
||||
# 模式标记
|
||||
_mode: str = "unknown" # "gpu_only" or "offload"
|
||||
|
||||
@classmethod
|
||||
def set_mode(cls, mode: str) -> None:
|
||||
"""设置当前模式 (gpu_only / offload)"""
|
||||
cls._mode = mode
|
||||
|
||||
@classmethod
|
||||
def record(
|
||||
cls,
|
||||
layer_id: int,
|
||||
mask: torch.Tensor,
|
||||
causal: bool = True,
|
||||
) -> float:
|
||||
"""
|
||||
记录一层的 density。
|
||||
|
||||
Args:
|
||||
layer_id: 层 ID
|
||||
mask: [batch, heads, q_blocks, k_blocks] boolean tensor
|
||||
causal: 是否考虑 causal mask (只计算下三角)
|
||||
|
||||
Returns:
|
||||
density 值
|
||||
"""
|
||||
if not cls._enabled:
|
||||
return 0.0
|
||||
|
||||
density = cls._compute_density(mask, causal)
|
||||
|
||||
# 记录
|
||||
if layer_id not in cls._layer_densities:
|
||||
cls._layer_densities[layer_id] = []
|
||||
cls._layer_densities[layer_id].append(density)
|
||||
|
||||
# 记录 mask shape
|
||||
cls._last_q_blocks = mask.shape[2]
|
||||
cls._last_k_blocks = mask.shape[3]
|
||||
|
||||
return density
|
||||
|
||||
@classmethod
|
||||
def _compute_density(cls, mask: torch.Tensor, causal: bool) -> float:
|
||||
"""计算 mask 的 density"""
|
||||
batch, heads, q_blocks, k_blocks = mask.shape
|
||||
|
||||
if causal:
|
||||
# 只计算下三角区域
|
||||
causal_mask = torch.tril(
|
||||
torch.ones(q_blocks, k_blocks, device=mask.device, dtype=torch.bool)
|
||||
)
|
||||
total_blocks = causal_mask.sum().item() * batch * heads
|
||||
selected_blocks = (mask & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
else:
|
||||
total_blocks = mask.numel()
|
||||
selected_blocks = mask.sum().item()
|
||||
|
||||
if total_blocks == 0:
|
||||
return 1.0
|
||||
|
||||
return selected_blocks / total_blocks
|
||||
|
||||
@classmethod
|
||||
def complete_reset(cls) -> None:
|
||||
"""重置所有统计"""
|
||||
cls._layer_densities = {}
|
||||
cls._last_q_blocks = 0
|
||||
cls._last_k_blocks = 0
|
||||
cls._mode = "unknown"
|
||||
|
||||
@classmethod
|
||||
def get_per_layer_density(cls) -> Dict[int, float]:
|
||||
"""获取每层的平均 density"""
|
||||
result = {}
|
||||
for layer_id, densities in cls._layer_densities.items():
|
||||
if densities:
|
||||
result[layer_id] = sum(densities) / len(densities)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_overall_density(cls) -> float:
|
||||
"""获取所有层的平均 density"""
|
||||
all_densities = []
|
||||
for densities in cls._layer_densities.values():
|
||||
all_densities.extend(densities)
|
||||
if not all_densities:
|
||||
return 0.0
|
||||
return sum(all_densities) / len(all_densities)
|
||||
|
||||
@classmethod
|
||||
def get_summary(cls) -> dict:
|
||||
"""返回统计摘要"""
|
||||
per_layer = cls.get_per_layer_density()
|
||||
return {
|
||||
"mode": cls._mode,
|
||||
"overall_density": cls.get_overall_density(),
|
||||
"per_layer_density": per_layer,
|
||||
"num_layers": len(per_layer),
|
||||
"last_mask_shape": {
|
||||
"q_blocks": cls._last_q_blocks,
|
||||
"k_blocks": cls._last_k_blocks,
|
||||
},
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_min_density(cls) -> Tuple[int, float]:
|
||||
"""获取最低 density 的层和值"""
|
||||
per_layer = cls.get_per_layer_density()
|
||||
if not per_layer:
|
||||
return -1, 0.0
|
||||
min_layer = min(per_layer, key=per_layer.get)
|
||||
return min_layer, per_layer[min_layer]
|
||||
|
||||
@classmethod
|
||||
def print_summary(cls) -> None:
|
||||
"""打印人类可读的摘要"""
|
||||
per_layer = cls.get_per_layer_density()
|
||||
overall = cls.get_overall_density()
|
||||
min_layer, min_density = cls.get_min_density()
|
||||
|
||||
print(f"[DensityObserver] Mode: {cls._mode}")
|
||||
print(f" Overall density: {overall:.4f}")
|
||||
print(f" Min density: {min_density:.4f} (layer {min_layer})")
|
||||
print(f" Num layers: {len(per_layer)}")
|
||||
Reference in New Issue
Block a user