Compare commits

2 Commits: 4484a1482c ... 2e96d1d97d

Commits: 2e96d1d97d, f6ac4ccdde

```diff
@@ -18,6 +18,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/xattn_kernels_guide.md`](docs/xattn_kernels_guide.md) | XAttention Triton kernels: flat_group_gemm (anti-diagonal summation), softmax_fuse_block_sum (block aggregation) |
 | [`docs/xattn_chunked_prefill.md`](docs/xattn_chunked_prefill.md) | XAttention chunked prefill: API, usage, consistency requirements |
 | [`docs/xattn_bsa_policy_design.md`](docs/xattn_bsa_policy_design.md) | XAttention BSA Policy: algorithm design, performance benchmark (128K), memory management, density statistics |
+| [`docs/xattn_density_benchmark.md`](docs/xattn_density_benchmark.md) | 📊 XAttention Density Benchmark: 4K-32K context, stride parameters, per-layer density analysis |
 | [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) interface: function signatures, usage examples, constraints |
 | [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
 | [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
@@ -36,6 +37,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/estimate_block_size_performance.md`](docs/estimate_block_size_performance.md) | 🔥 PERF: estimate-phase block_size performance analysis; softmax_fuse_block_sum optimum (512-1024); current 4096 is 15x slower |
 | [`docs/long_context_models_1m.md`](docs/long_context_models_1m.md) | 📚 REF: list of 1M+ context-length models (Qwen/GLM/InternLM/Llama/VL), recommended models ≤10B |
 | [`docs/new_model_integration_guide.md`](docs/new_model_integration_guide.md) | 🔧 GUIDE: new-model integration guide - config mapping, RoPE variants, EOS handling, weight conversion, validation checklist |
+| [`docs/xattn_density_alignment_analysis.md`](docs/xattn_density_alignment_analysis.md) | 📊 ANALYSIS: GPU-only vs offload mode density alignment, chunked-softmax boundary effects, root cause of the 5-7% gap |

 ## Rules Index
```
docs/xattn_density_benchmark.md (new file, 195 lines)

@@ -0,0 +1,195 @@

# XAttention Density Benchmark

Density test results for XAttention Block Sparse Attention in GPU-only mode.

## Test Configuration

| Parameter | Value | Notes |
|------|-----|------|
| Model | Llama-3.1-8B-Instruct | 32 layers, 32 heads, 8 KV heads |
| Block Size | 128 tokens | Fixed requirement of the BSA kernel |
| Threshold | 0.9 / 0.95 | Cumulative attention threshold |
| Stride | 4 / 8 / 16 | Q/K downsampling stride |
| Dataset | RULER niah_single_1 | Sample 0 |
| Mode | GPU-only | No CPU offload |

## Density Definition

Density is the fraction of selected blocks over all causal blocks: `selected_blocks / total_causal_blocks`. Under causal attention, only blocks in the lower-triangular region count; the overall density is the average across all layers.

```python
import torch

def compute_density(mask: torch.Tensor, causal: bool = True) -> float:
    """
    mask: [batch, heads, q_blocks, k_blocks] boolean tensor of selected blocks.
    Returns selected_blocks / total_causal_blocks.
    """
    batch, heads, q_blocks, k_blocks = mask.shape
    if causal:
        # Only lower-triangular (causal) blocks count toward the total
        causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, dtype=torch.bool, device=mask.device))
        total = causal_mask.sum().item() * batch * heads
        selected = (mask & causal_mask).sum().item()
    else:
        total = mask.numel()
        selected = mask.sum().item()
    return selected / total
```
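As a self-contained toy check of this definition (hypothetical numbers, not from the benchmark): on a 4×4 block grid with two heads where only the diagonal blocks are selected, the causal region holds 10 blocks per head, so the density is 8/20 = 0.4.

```python
import torch

# Toy mask: 1 batch, 2 heads, 4x4 block grid; only diagonal blocks selected.
mask = torch.zeros(1, 2, 4, 4, dtype=torch.bool)
mask[..., torch.arange(4), torch.arange(4)] = True

# Causal region: lower-triangular blocks only (10 per head).
causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.bool))
total = causal_mask.sum().item() * 1 * 2      # 20 causal blocks across 2 heads
selected = (mask & causal_mask).sum().item()  # 8 diagonal blocks across 2 heads
density = selected / total
print(density)  # 0.4
```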

## Test Results

### threshold=0.9

#### Overall Density (average)

| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.5220 (52.2%) | 0.5292 (52.9%) | 0.5430 (54.3%) |
| **8K** | 0.5152 (51.5%) | 0.5252 (52.5%) | 0.5396 (54.0%) |
| **16K** | 0.4682 (46.8%) | 0.4775 (47.8%) | 0.4888 (48.9%) |
| **32K** | 0.3700 (37.0%) | 0.4012 (40.1%) | 0.4196 (42.0%) |

#### Min Density (per layer)

| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.2805 (Layer 3) | 0.3132 (Layer 3) | 0.3376 (Layer 5) |
| **8K** | 0.2886 (Layer 5) | 0.2725 (Layer 5) | 0.2995 (Layer 5) |
| **16K** | 0.2247 (Layer 5) | 0.2349 (Layer 5) | 0.2451 (Layer 5) |
| **32K** | 0.1799 (Layer 5) | 0.1846 (Layer 5) | 0.1964 (Layer 5) |

### threshold=0.95

#### Overall Density (average)

| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.6561 (65.6%) | 0.6699 (67.0%) | 0.6815 (68.2%) |
| **8K** | 0.6462 (64.6%) | 0.6584 (65.8%) | 0.6732 (67.3%) |
| **16K** | 0.6004 (60.0%) | 0.6114 (61.1%) | 0.6193 (61.9%) |
| **32K** | 0.4894 (48.9%) | 0.5203 (52.0%) | 0.5385 (53.9%) |

#### Min Density (per layer)

| Context | stride=4 | stride=8 | stride=16 |
|---------|----------|----------|-----------|
| **4K** | 0.3972 (Layer 3) | 0.4348 (Layer 5) | 0.4517 (Layer 4) |
| **8K** | 0.4004 (Layer 5) | 0.3906 (Layer 5) | 0.4239 (Layer 5) |
| **16K** | 0.3331 (Layer 5) | 0.3453 (Layer 5) | 0.3589 (Layer 5) |
| **32K** | 0.2656 (Layer 5) | 0.2784 (Layer 5) | 0.2917 (Layer 5) |

### Threshold Comparison (stride=8)

| Context | threshold=0.9 | threshold=0.95 | Difference |
|---------|---------------|----------------|------|
| **4K** | 0.5292 (52.9%) | 0.6699 (67.0%) | -14.1% |
| **8K** | 0.5252 (52.5%) | 0.6584 (65.8%) | -13.3% |
| **16K** | 0.4775 (47.8%) | 0.6114 (61.1%) | -13.4% |
| **32K** | 0.4012 (40.1%) | 0.5203 (52.0%) | -11.9% |
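The threshold column maps to block selection by cumulative attention mass: blocks are ranked by score and taken greedily until their combined share of the total reaches the threshold. A minimal sketch of that rule (illustrative scores; `select_by_threshold` is a hypothetical helper, not the repo's API):

```python
import torch

def select_by_threshold(block_scores: torch.Tensor, threshold: float) -> list[int]:
    """Pick the highest-scoring blocks until their share of the total
    attention mass reaches `threshold` (e.g. 0.9)."""
    ratios = block_scores / block_scores.sum()
    order = torch.argsort(ratios, descending=True)
    chosen, cumulative = [], 0.0
    for idx in order.tolist():
        chosen.append(idx)
        cumulative += ratios[idx].item()
        if cumulative >= threshold:
            break
    return sorted(chosen)

scores = torch.tensor([0.05, 0.40, 0.10, 0.30, 0.15])
print(select_by_threshold(scores, 0.9))  # [1, 2, 3, 4]
```

A higher threshold keeps more blocks (higher density); lowering it from 0.95 to 0.9 drops the weakest-scoring blocks first, which is why the density gap in the table is roughly constant across context lengths.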

## Key Findings

### 1. Context length matters most

Density drops noticeably as context length grows (threshold=0.9, stride=8):

- 4K: 52.9% density
- 8K: 52.5% density
- 16K: 47.8% density
- 32K: 40.1% density

**Conclusion**: longer sequences offer more sparsification opportunity, so XAttention's advantage is more pronounced at long context.

### 2. Threshold has a significant effect

threshold=0.9 yields roughly 12-14% lower density than 0.95:

- 0.9 is more aggressive and selects fewer blocks
- 0.95 is more conservative and keeps more blocks
- Accuracy is unaffected either way (RULER NIAH passes in all cases)

### 3. Stride has a small effect

At a given context length, density varies only about 2-5% across strides:

- Larger stride → slightly higher density (coarser sampling, more conservative selection)
- stride=4 is the most aggressive, stride=16 the most conservative

### 4. Min density concentrates in the middle layers

- In most cases the minimum density occurs at Layer 5
- Middle layers are the sparsest; the first and last layers are relatively dense
- This matches the general pattern of Transformer attention

### 5. Best sparsification configuration

32K + stride=4 + threshold=0.9 reaches the lowest density:

- Overall: **37.0%** (saving 63% of compute)
- Min: **18.0%** (Layer 5)

### 6. Accuracy is stable

The RULER NIAH test passes (score=1.0) under every configuration, indicating:

- Both threshold=0.9 and 0.95 are conservative enough to preserve accuracy
- Different strides do not change the final result
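For intuition on the stride effect above: stride controls how coarsely Q/K are sampled before the block-score estimate. The real kernel (flat_group_gemm) uses an anti-diagonal summation over stride groups; the simplified sketch below just keeps every stride-th row, which is enough to show how downsampled scores aggregate to per-block attention mass (`strided_block_scores` is a hypothetical illustration, not the repo's API):

```python
import torch

def strided_block_scores(q: torch.Tensor, k: torch.Tensor,
                         stride: int, block_size: int) -> torch.Tensor:
    """Estimate per-K-block attention mass from stride-downsampled Q/K.
    q, k: [seq_len, head_dim]; seq_len divisible by stride and block_size."""
    q_ds = q[::stride]  # keep every stride-th query row
    k_ds = k[::stride]  # keep every stride-th key row
    scores = torch.softmax(q_ds @ k_ds.T / q.shape[-1] ** 0.5, dim=-1)
    # Aggregate downsampled columns back to block granularity
    k_blocks = k.shape[0] // block_size
    cols_per_block = block_size // stride
    return scores.view(scores.shape[0], k_blocks, cols_per_block).sum(-1).mean(0)

torch.manual_seed(0)
q, k = torch.randn(256, 64), torch.randn(256, 64)
per_block = strided_block_scores(q, k, stride=8, block_size=128)
print(per_block.shape)  # 2 K blocks; the per-block masses sum to ~1
```

A larger stride leaves fewer samples per block, so the estimate is noisier; selecting conservatively under that noise is consistent with the slightly higher densities at stride=16.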

## Recommended Configurations

| Scenario | threshold | stride | Notes |
|------|-----------|--------|------|
| Accuracy first | 0.95 | 8 | Conservative; density ~52-67% |
| Balanced | 0.9 | 8 | Default; density ~40-53% |
| Performance first | 0.9 | 4 | Aggressive; density ~37-52% |

## Test Command

```bash
# Basic run
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --sample-indices 0 \
    --max-model-len 33792 \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9 \
    --sparse-stride 8 \
    --gpu-utilization 0.85

# Flags
# --sparse-policy XATTN_BSA   enable XAttention Block Sparse Attention
# --sparse-threshold 0.9      cumulative attention threshold (0.9-0.99)
# --sparse-stride 8           Q/K downsampling stride (4/8/16)
```

## DensityObserver Usage

```python
from nanovllm.utils.density_observer import DensityObserver

# Enable and reset
DensityObserver.enable()
DensityObserver.complete_reset()

# ... run inference (compute_prefill records automatically) ...

# Fetch results
summary = DensityObserver.get_summary()
# {
#     "mode": "gpu_only",
#     "overall_density": 0.40,  # average across all layers
#     "per_layer_density": {0: 0.55, 1: 0.45, ...},
#     "num_layers": 32
# }

# Lowest density
min_layer, min_density = DensityObserver.get_min_density()

# Print a summary
DensityObserver.print_summary()
# [DensityObserver] Mode: gpu_only
#   Overall density: 0.4012
#   Min density: 0.1846 (layer 5)
#   Num layers: 32
```
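For intuition, the aggregation DensityObserver performs can be sketched as follows (`MiniDensityObserver` is a simplified stand-in written for this note, not the repo's class):

```python
import torch

class MiniDensityObserver:
    """Simplified stand-in: record one density per layer from a block mask,
    then aggregate overall/min, mirroring get_summary()'s fields."""
    def __init__(self):
        self.per_layer = {}

    def record(self, layer_id: int, mask: torch.Tensor, causal: bool = True):
        b, h, qb, kb = mask.shape
        region = torch.tril(torch.ones(qb, kb, dtype=torch.bool)) if causal \
            else torch.ones(qb, kb, dtype=torch.bool)
        total = region.sum().item() * b * h
        self.per_layer[layer_id] = (mask & region).sum().item() / total

    def summary(self) -> dict:
        densities = list(self.per_layer.values())
        min_layer = min(self.per_layer, key=self.per_layer.get)
        return {"overall_density": sum(densities) / len(densities),
                "min_density": (min_layer, self.per_layer[min_layer]),
                "num_layers": len(self.per_layer)}

obs = MiniDensityObserver()
torch.manual_seed(0)
for layer in range(4):
    obs.record(layer, torch.rand(1, 2, 8, 8) < 0.5)  # random toy masks
print(obs.summary())
```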

## Related Files

| File | Description |
|------|------|
| `nanovllm/kvcache/sparse/xattn_bsa.py` | XAttention BSA Policy implementation |
| `nanovllm/utils/density_observer.py` | Density statistics observer |
| `nanovllm/ops/xattn.py` | xattn_estimate core algorithm |
| `tests/test_ruler.py` | RULER benchmark test script |

```diff
@@ -229,16 +229,16 @@ class ModelRunner:

         # GPU-only mode: pre-allocate policy metadata buffers
         # This avoids dynamic GPU memory allocation during forward pass
-        if not config.enable_cpu_offload:
-            num_heads = hf_config.num_attention_heads // self.world_size
-            self.kvcache_manager.sparse_policy.alloc_policy_metadata(
-                num_heads=num_heads,
-                num_kv_heads=num_kv_heads,
-                head_dim=head_dim,
-                max_seq_len=config.max_model_len,
-                dtype=hf_config.torch_dtype,
-                device=torch.device("cuda"),
-            )
+        # if not config.enable_cpu_offload:
+        num_heads = hf_config.num_attention_heads // self.world_size
+        self.kvcache_manager.sparse_policy.alloc_policy_metadata(
+            num_heads=num_heads,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
+            max_seq_len=config.max_model_len,
+            dtype=hf_config.torch_dtype,
+            device=torch.device("cuda"),
+        )

         # Log policy info (handle both enum and None cases)
         policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"
```

```diff
@@ -47,6 +47,8 @@ class FullAttentionPolicy(SparsePolicy):
         available_blocks: List[int],
         offload_engine: "OffloadEngine",
         ctx: PolicyContext,
+        q: torch.Tensor,
+        k: torch.Tensor,
     ) -> List[int]:
         """Return all blocks - no sparsity."""
         # Update statistics (only for layer 0 to avoid overcounting)
```

```diff
@@ -142,6 +142,8 @@ class SparsePolicy(ABC):
         available_blocks: List[int],
         offload_engine: "OffloadEngine",
         ctx: PolicyContext,
+        q: torch.Tensor,
+        k: torch.Tensor,
     ) -> List[int]:
         """
         Select which KV blocks to load for the current query chunk.
@@ -158,6 +160,8 @@ class SparsePolicy(ABC):
                 to load KV to make selection decisions).
             ctx: PolicyContext with information about the current query
                 chunk, layer, phase (prefill/decode), etc.
+            q: Query tensor [seq_len, num_heads, head_dim] for current chunk
+            k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk

         Returns:
             List of block IDs to load (must be a subset of available_blocks).
```

```diff
@@ -191,13 +191,26 @@ class QuestPolicy(SparsePolicy):
     def select_blocks(
         self,
         available_blocks: List[int],
         offload_engine: "OffloadEngine",
         ctx: PolicyContext,
+        q: torch.Tensor,
+        k: torch.Tensor,
     ) -> List[int]:
         """
         Select Top-K blocks based on query-key similarity bounds.

         If query is not available (some prefill scenarios), falls back
         to loading all blocks.
+
+        Args:
+            available_blocks: List of CPU block IDs
+            offload_engine: OffloadEngine for loading KV (unused in Quest)
+            ctx: PolicyContext with metadata
+            q: Query tensor [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
+            k: Key tensor [seq_len, num_kv_heads, head_dim] (unused in Quest, uses metadata instead)
+
+        Returns:
+            Selected block IDs
         """
         if self.metadata is None:
             raise RuntimeError(
@@ -211,7 +224,7 @@ class QuestPolicy(SparsePolicy):
         if n <= self.config.threshold_blocks:
             return available_blocks

-        if ctx.query is None:
+        if q is None:
             # No query available - cannot compute scores
             return available_blocks
@@ -221,11 +234,10 @@ class QuestPolicy(SparsePolicy):
             )

         # Metadata is already on GPU, same device as query
-        device = ctx.query.device
+        device = q.device

         # Compute upper bound scores
-        # query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
-        q = ctx.query
+        # query shape: [seq_len, num_heads, head_dim] or [batch, seq_len, num_heads, head_dim]
         if q.dim() == 4:
             # Prefill: use mean over sequence length
             q = q.mean(dim=1)  # [1, num_heads, head_dim]
```

```diff
@@ -17,6 +17,7 @@ import torch.cuda.nvtx as nvtx
 from typing import List, Tuple, TYPE_CHECKING

 from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
+from nanovllm.utils.density_observer import DensityObserver

 if TYPE_CHECKING:
     from nanovllm.kvcache.offload_engine import OffloadEngine
```

```diff
@@ -134,6 +135,21 @@ class XAttentionBSAPolicy(SparsePolicy):
         self._v_expanded: torch.Tensor | None = None
         self._max_seq_len: int = 0

+        # Pre-allocated mask buffer for chunked prefill (offload mode)
+        # Stores BSA-level mask from select_blocks for use in compute_chunked_prefill
+        # Shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks]
+        self._prefill_mask_buffer: torch.Tensor | None = None
+        self._current_mask_q_bsa: int = 0  # Current Q BSA blocks in buffer
+        self._current_mask_k_bsa: int = 0  # Current K BSA blocks in buffer
+
+        # Selected block indices for mask extraction in compute_chunked_prefill
+        # Stores the indices of selected CPU blocks in available_blocks
+        self._selected_cpu_indices: List[int] = []
+        self._bsa_per_cpu: int = 0  # BSA blocks per CPU block
+
+        #> Debug: store all K cache
+        self._debug_k_full: torch.Tensor | None = None
+
     def alloc_policy_metadata(
         self,
         num_heads: int,
```

```diff
@@ -161,7 +177,17 @@ class XAttentionBSAPolicy(SparsePolicy):
             dtype: Data type
             device: Target device
         """
-        # Only allocate if GQA (num_heads != num_kv_heads)
+        # Pre-allocate mask buffer for chunked prefill (offload mode)
+        # mask shape: [1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks]
+        # This is needed regardless of GQA
+        max_q_bsa_blocks = self.chunk_size // self.BSA_BLOCK_SIZE
+        max_k_bsa_blocks = max_seq_len // self.BSA_BLOCK_SIZE
+        mask_shape = (1, num_heads, max_q_bsa_blocks, max_k_bsa_blocks)
+        self._prefill_mask_buffer = torch.empty(mask_shape, dtype=torch.bool, device=device)
+        mask_memory_mb = num_heads * max_q_bsa_blocks * max_k_bsa_blocks / (1024 * 1024)
+        logger.info(f"[XAttn] Pre-allocated mask buffer: shape={mask_shape}, memory={mask_memory_mb:.1f} MB")
+
+        # Only allocate GQA expansion buffers if GQA (num_heads != num_kv_heads)
         if num_heads == num_kv_heads:
             logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
             return
```

```diff
@@ -175,6 +201,9 @@

         memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)
         logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")

+        #DEBUG : buffer for save all K cache.
+        self._debug_k_full = torch.empty((1, num_heads, max_seq_len, head_dim), dtype=dtype, device=device)
+
     # =========================================================================
     # GPU-only methods (non-chunked)
```

```diff
@@ -258,6 +287,10 @@

         from nanovllm.ops.xattn import xattn_estimate

+        # Set DensityObserver mode on first layer
+        if layer_id == 0:
+            DensityObserver.set_mode("gpu_only")
+
         # Get dimensions
         total_q, num_heads, head_dim = q.shape
         total_kv, num_kv_heads, _ = k.shape
```

```diff
@@ -315,6 +348,7 @@
             Q, K_exp,
             chunk_size=self.chunk_size,
             block_size=self.BSA_BLOCK_SIZE,
+            stride=self.stride,
             threshold=self.threshold,
             use_triton=self.use_triton,
             causal=True,
```

```diff
@@ -360,13 +394,8 @@
             is_causal=True,
         )

-        # Update statistics (layer 0 only to avoid overcounting)
-        if layer_id == 0:
-            selected_blocks = mask_trimmed.sum().item()
-            total_blocks = q_block_num * k_block_num * num_heads
-            density = selected_blocks / total_blocks if total_blocks > 0 else 1.0
-            logger.debug(f"[XAttn GPU-only] layer={layer_id}, q_blocks={q_block_num}, "
-                         f"k_blocks={k_block_num}, density={density:.1%}")
+        # Record density for all layers via DensityObserver
+        DensityObserver.record(layer_id, mask_trimmed, causal=True)

         return output
```

```diff
@@ -400,33 +429,42 @@
         available_blocks: List[int],
         offload_engine: "OffloadEngine",
         ctx: PolicyContext,
+        q: torch.Tensor,
+        k: torch.Tensor,
     ) -> List[int]:
         """
         Compute attention scores for all available blocks using flat_group_gemm,
         then use softmax_fuse_block_sum and find_blocks_chunked to select important blocks.

-        This method:
-        1. Loads each K block from CPU
-        2. Computes Q@K^T attention scores using XAttention stride reshape
-        3. Applies softmax_fuse_block_sum to get block-level attention
-        4. Uses find_blocks_chunked to select blocks based on threshold
+        This method aligns with GPU-only xattn_estimate_chunked:
+        1. Loads each K block from CPU (historical blocks)
+        2. Gets current chunk K from prefill buffer
+        3. Concatenates [historical K, current chunk K] for correct softmax normalization
+        4. Uses causal=True with correct chunk_start for position-aware masking
+        5. Only selects from historical blocks (current chunk is always full attention)

         Args:
-            available_blocks: List of CPU block IDs
+            available_blocks: List of CPU block IDs (historical blocks only)
             offload_engine: OffloadEngine for loading blocks
-            ctx: PolicyContext with query tensor and metadata
+            ctx: PolicyContext with metadata
+            q: Query tensor [seq_len, num_heads, head_dim] for current chunk
+            k: Key tensor [seq_len, num_kv_heads, head_dim] for current chunk (used for estimation)

         Returns:
             Selected block IDs based on attention threshold
         """
-        if not available_blocks or ctx.query is None:
+        if q is None:
             return available_blocks

         from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum, find_blocks_chunked
         import math

         layer_id = ctx.layer_id
-        q = ctx.query  # [seq_len, num_heads, head_dim]
+        # Use passed q parameter instead of ctx.query
+
+        # Set DensityObserver mode on first layer
+        if layer_id == 0:
+            DensityObserver.set_mode("offload")

         # Convert Q to [batch, heads, seq_len, head_dim]
         # q: [seq_len, num_heads, head_dim] -> [1, num_heads, seq_len, head_dim]
```

```diff
@@ -453,18 +491,37 @@
         q_reshaped_len = padded_q_len // self.stride

-        # Use a single slot for loading (synchronous mode for simplicity)
+        # Get block size from context
+        block_size = ctx.block_size  # tokens per CPU block (e.g., 4096)
+        reshaped_block_size = block_size // self.stride  # e.g., 4096/8 = 512
+
+        # ============================================================
+        # Step 1: Compute chunk_start and related parameters
+        # ============================================================
+        # chunk_start = Q's global position in reshaped space
+        # Q starts at position: num_historical_blocks * block_size
+        num_historical_blocks = len(available_blocks)
+        historical_k_len = num_historical_blocks * block_size
+        chunk_start = historical_k_len // self.stride  # Q's position in reshaped space
+        chunk_end = chunk_start + q_reshaped_len
+
+        # For valid Q length tracking (excluding padding)
+        valid_q_reshaped = (q_len + self.stride - 1) // self.stride
+        real_q_len = chunk_start + valid_q_reshaped
+
+        # ============================================================
+        # Step 2: Pipeline load historical K blocks and compute attn_scores
+        # ============================================================
+        # Key design: Load each block, compute immediately, then release
+        # This avoids storing all K in GPU memory at once (offload-friendly)
         slot = 0
         attn_scores_list = []
         BLOCK_N = 128
         k_alignment = self.stride * BLOCK_N

-        # Get block size from context
-        block_size = ctx.block_size  # tokens per CPU block (e.g., 1024)
-        reshaped_block_size = block_size // self.stride  # e.g., 1024/8 = 128
-
-        with nvtx.range("xattn_estimate_gemm"):
+        with nvtx.range("xattn_estimate_historical"):
             for cpu_block_id in available_blocks:
                 # Load only K from CPU to GPU (V not needed for estimate)
                 # This saves 50% communication in the estimate phase
                 offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
                 offload_engine.wait_slot_layer(slot)
```
@@ -472,125 +529,228 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
k_block = offload_engine.get_k_for_slot(slot)
|
||||
|
||||
# Convert K to [batch, heads, k_len, head_dim]
|
||||
# k_block: [1, block_size, num_kv_heads, head_dim] -> [1, num_kv_heads, block_size, head_dim]
|
||||
K_chunk = k_block.transpose(1, 2)
|
||||
|
||||
K_chunk = k_block.transpose(1, 2) # [1, num_kv_heads, block_size, head_dim]
|
||||
|
||||
# Handle GQA: expand K heads to match Q heads
|
||||
num_kv_heads = K_chunk.shape[1]
|
||||
if num_heads != num_kv_heads:
|
||||
num_groups = num_heads // num_kv_heads
|
||||
K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
|
||||
|
||||
#> DEBUG: save all K cache
|
||||
start_pos = cpu_block_id * block_size
|
||||
self._debug_k_full[:, :, start_pos:start_pos + block_size, :].copy_(K_chunk)
|
||||
|
||||
# # Pad K if necessary
|
||||
# k_len = K_chunk.shape[2]
|
||||
# if k_len < k_alignment:
|
||||
# pad_size = k_alignment - k_len
|
||||
# K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
|
||||
|
||||
# Pad K if necessary (k_len must be divisible by stride * BLOCK_N)
|
||||
k_len = K_chunk.shape[2]
|
||||
BLOCK_N = 128
|
||||
k_alignment = self.stride * BLOCK_N
|
||||
if k_len < k_alignment:
|
||||
# K too short, pad it
|
||||
pad_size = k_alignment - k_len
|
||||
K_chunk = torch.nn.functional.pad(K_chunk, (0, 0, 0, pad_size), value=0)
|
||||
|
||||
# Compute attention scores using flat_group_gemm_fuse_reshape
|
||||
# Output: [batch, heads, q_len/stride, k_len/stride]
|
||||
attn_chunk = flat_group_gemm_fuse_reshape(
|
||||
Q, K_chunk, self.stride,
|
||||
chunk_start=0,
|
||||
chunk_end=q_reshaped_len,
|
||||
is_causal=False
|
||||
)
|
||||
attn_scores_list.append(attn_chunk)
|
||||
# # Compute attention scores for this historical block
|
||||
# # Historical blocks: all positions < Q, so Q always sees them (full attention)
|
||||
# # Use LOCAL chunk_start=0 to match test_xattn_k_chunked.py behavior
|
||||
# attn_chunk = flat_group_gemm_fuse_reshape(
|
||||
# Q, K_chunk, self.stride,
|
||||
# chunk_start=0, # Local: same as test
|
||||
# chunk_end=q_reshaped_len,
|
||||
# is_causal=False, # Historical K: all visible to Q
|
||||
# )
|
||||
# attn_scores_list.append(attn_chunk)
|
||||
|
||||
# Mark slot as done for reuse
|
||||
offload_engine.record_slot_compute_done(slot)
|
||||
|
||||
num_kv_heads = k.shape[1]
|
||||
if num_heads != num_kv_heads:
|
||||
num_groups = num_heads // num_kv_heads
|
||||
k_repeated = k.repeat_interleave(num_groups, dim=1).unsqueeze(0).transpose(1, 2) # [1, num_heads, historical_k_len, head_dim]
|
||||
|
||||
self._debug_k_full[:, :, historical_k_len:historical_k_len + q_len, :].copy_(k_repeated)
|
||||
|
||||
if layer_id == 0:
|
||||
__import__('pdb').set_trace()
|
||||
|
||||
# Concatenate all attention scores along K dimension
|
||||
# Each chunk: [1, heads, q_reshaped_len, block_reshaped_len]
|
||||
# Result: [1, heads, q_reshaped_len, total_k_reshaped_len]
|
||||
# ============================================================
|
||||
# Step 3: Get current chunk K and compute its attn_scores
|
||||
# ============================================================
|
||||
with nvtx.range("xattn_estimate_current"):
|
||||
# Current chunk K is in prefill buffer (already on GPU)
|
||||
k_curr, _ = offload_engine.get_prefill_buffer_slice(layer_id, q_len)
|
||||
# k_curr: [1, q_len, num_kv_heads, head_dim] -> [1, num_kv_heads, q_len, head_dim]
|
||||
K_current = k_curr.transpose(1, 2)
|
||||
|
||||
# Handle GQA for current chunk K
|
||||
num_kv_heads = K_current.shape[1]
|
||||
if num_heads != num_kv_heads:
|
||||
num_groups = num_heads // num_kv_heads
|
||||
K_current = K_current.repeat_interleave(num_groups, dim=1)
|
||||
|
||||
# Pad current K if necessary
|
||||
curr_k_len = K_current.shape[2]
|
||||
padded_curr_k_len = ((curr_k_len + k_alignment - 1) // k_alignment) * k_alignment
|
||||
if padded_curr_k_len != curr_k_len:
|
||||
pad_size = padded_curr_k_len - curr_k_len
|
||||
K_current = torch.nn.functional.pad(K_current, (0, 0, 0, pad_size), value=0)
|
||||
|
||||
# Compute attention scores for current chunk
|
||||
# IMPORTANT: Use LOCAL coordinates (0 to q_reshaped_len) for current chunk!
|
||||
# Because K_current only contains current chunk K (not full sequence),
|
||||
# block_n in kernel starts from 0. Using global chunk_start would cause
|
||||
# incorrect causal mask (Q would see K blocks it shouldn't).
|
||||
attn_current = flat_group_gemm_fuse_reshape(
|
||||
Q, K_current, self.stride,
|
||||
chunk_start=0, # Local: Q starts at 0 relative to K_current
|
||||
chunk_end=q_reshaped_len, # Local: Q ends at q_reshaped_len
|
||||
is_causal=True, # Current chunk: apply causal mask
|
||||
)
|
||||
attn_scores_list.append(attn_current)
|
||||
del K_current
|
||||
|
||||
# ============================================================
|
||||
# Step 4: Concatenate all attn_scores
|
||||
# ============================================================
|
||||
if not attn_scores_list:
|
||||
return available_blocks
|
||||
|
||||
attn_scores = torch.cat(attn_scores_list, dim=-1)
|
||||
# Free intermediate list immediately
|
||||
del attn_scores_list
|
||||
|
||||
# Step 2: Apply softmax_fuse_block_sum with hierarchical aggregation
|
||||
# Use smaller estimate_block_size (1024) for 15x faster softmax kernel,
|
||||
# then aggregate to CPU block level (4096).
|
||||
#
|
||||
# Hierarchical approach:
|
||||
# 1. softmax_fuse_block_sum with estimate_block_size (1024) -> fine-grained scores
|
||||
# 2. Aggregate: reshape + sum -> CPU block level scores
|
||||
# 3. Select blocks based on score + threshold (NOT mask + voting)
|
||||
cpu_block_size = block_size # e.g., 4096
|
||||
estimate_bs = self.estimate_block_size # e.g., 1024 (15x faster)
|
||||
ratio = cpu_block_size // estimate_bs # e.g., 4
|
||||
# Calculate padded K length for later use
|
||||
padded_k_len = historical_k_len + padded_curr_k_len
|
||||
|
||||
# Use estimate_block_size for softmax kernel (optimized)
|
||||
reshaped_est_bs = estimate_bs // self.stride # e.g., 1024/8 = 128
|
||||
norm = 1.0 # Normalization factor
|
||||
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm # log2(e) with scaling
|
||||
segment_size = min(4096, reshaped_est_bs)
|
||||
# ============================================================
|
||||
# Step 5: Apply softmax_fuse_block_sum with causal=True
|
||||
# ============================================================
|
||||
cpu_block_size = block_size # e.g., 4096
|
||||
bsa_per_cpu = cpu_block_size // self.BSA_BLOCK_SIZE # e.g., 4096/128 = 32
|
||||
|
||||
# Use BSA_BLOCK_SIZE for block aggregation (aligned with GPU-only)
|
||||
reshaped_bsa_bs = self.BSA_BLOCK_SIZE // self.stride # e.g., 128/8 = 16
|
||||
norm = 1.0
|
||||
scale = 1.4426950408889634 / math.sqrt(head_dim) / self.stride / norm
|
||||
segment_size = min(4096, reshaped_bsa_bs)
|
||||
|
||||
with nvtx.range("xattn_estimate_softmax"):
|
||||
block_sums_fine = softmax_fuse_block_sum(
|
||||
block_sums = softmax_fuse_block_sum(
|
||||
attn_scores,
|
||||
reshaped_est_bs, # Use optimized estimate block size (128 vs 512)
|
||||
reshaped_bsa_bs,
|
||||
segment_size,
|
||||
chunk_start=0,
|
||||
chunk_end=q_reshaped_len,
|
||||
real_q_len=q_reshaped_len,
|
||||
chunk_start=chunk_start,
|
||||
chunk_end=chunk_end,
|
||||
real_q_len=real_q_len,
|
||||
scale=scale,
|
||||
is_causal=False, # Historical blocks are all before current chunk
|
||||
is_causal=True, # Causal for consistent with GPU-only
|
||||
)
|
||||
# block_sums_fine shape: [batch, heads, q_est_blocks, k_est_blocks]
|
||||
# where k_est_blocks = len(available_blocks) * ratio
|
||||
# block_sums shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]
|
||||
|
||||
# Step 3: Aggregate to CPU block level (hierarchical sum)
|
||||
# This is mathematically equivalent to direct computation but much faster
|
||||
batch_size_bs, num_heads_bs, q_est_blocks, k_est_blocks = block_sums_fine.shape
|
||||
num_cpu_blocks = len(available_blocks)
|
||||
# ============================================================
|
||||
# Step 6: Use find_blocks_chunked to generate BSA-level mask
|
||||
# ============================================================
|
||||
# Calculate BSA block indices
|
||||
q_bsa_blocks = (padded_q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||
total_k_bsa_blocks = (padded_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||
historical_k_bsa_blocks = num_historical_blocks * bsa_per_cpu
|
||||
|
||||
with nvtx.range("xattn_estimate_aggregate"):
|
||||
# Reshape: [batch, heads, q_est, k_est] -> [batch, heads, q_est, num_cpu, ratio]
|
||||
block_sums_coarse = block_sums_fine.view(
|
||||
batch_size_bs, num_heads_bs, q_est_blocks, num_cpu_blocks, ratio
|
||||
).sum(dim=-1) # [batch, heads, q_est_blocks, num_cpu_blocks]
|
||||
# current_index for find_blocks_chunked: Q's block offset
|
||||
q_start_bsa_block = historical_k_bsa_blocks # Q starts after historical K
|
||||
|
||||
# Sum over Q dimension to get total attention from Q chunk to each K block
|
||||
cpu_block_scores = block_sums_coarse.sum(dim=2) # [batch, heads, num_cpu_blocks]
|
||||
with nvtx.range("xattn_find_blocks"):
|
||||
mask = find_blocks_chunked(
|
||||
block_sums,
|
||||
current_index=q_start_bsa_block, # Q's position in BSA blocks
|
||||
threshold=self.threshold,
|
||||
num_to_choose=None,
|
||||
decoding=False,
|
||||
mode="prefill",
|
||||
causal=True, # Causal for block-level mask
|
||||
)
|
||||
# mask shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]
|
||||
|
||||
# Step 4: Select blocks using score + threshold (replaces mask + majority voting)
|
||||
# This is simpler and more direct than the original mask-based approach
|
||||
with nvtx.range("xattn_estimate_select"):
|
||||
# Average scores across heads (GQA-aware: all heads contribute equally)
|
||||
scores_per_block = cpu_block_scores.mean(dim=(0, 1)) # [num_cpu_blocks]
|
||||
# ============================================================
|
||||
# Step 7: Extract mask portions and record density
|
||||
# ============================================================
|
||||
B, H, Q_bsa, K_bsa_total = mask.shape
# Normalize to get the attention distribution
total_score = scores_per_block.sum()
if total_score > 0:
    score_ratio = scores_per_block / total_score
else:
    # Edge case: all scores are zero, so select all blocks
    selected_block_ids = list(available_blocks)
    if layer_id == 0 and available_blocks:
        self._stats_total_available_blocks += len(available_blocks)
        self._stats_total_selected_blocks += len(selected_block_ids)
        self._stats_num_chunks += 1
    return selected_block_ids
# Calculate valid Q blocks (excluding padding)
valid_q_bsa = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
valid_curr_k_bsa = (curr_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE

# Sort by score (descending) and select until the threshold is reached
sorted_indices = torch.argsort(score_ratio, descending=True)
cumsum = 0.0
selected_indices = set()
for idx in sorted_indices.tolist():
    selected_indices.add(idx)
    cumsum += score_ratio[idx].item()
    if cumsum >= self.threshold:
        break

# Map indices back to block IDs
selected_block_ids = [available_blocks[i] for i in sorted(selected_indices)]

# 7a: Record historical block density.
# IMPORTANT: for historical blocks, apply a causal mask so the result matches
# the GPU-only density calculation. Q block i (global position =
# q_start_bsa_block + i) can see historical K block j only if
# j <= q_start_bsa_block + i (causal constraint).
mask_historical = mask[:, :, :valid_q_bsa, :historical_k_bsa_blocks]

if historical_k_bsa_blocks > 0:
    # Build the causal mask for historical blocks:
    # Q_global[i] = q_start_bsa_block + i, K[j] = j
    # Causal: j <= Q_global[i]  =>  j <= q_start_bsa_block + i
    q_global_indices = torch.arange(valid_q_bsa, device=mask.device) + q_start_bsa_block
    k_indices = torch.arange(historical_k_bsa_blocks, device=mask.device)
    # Q at q_global_indices[i] can see K at k_indices[j] iff k_indices[j] <= q_global_indices[i]
    causal_mask_historical = k_indices.unsqueeze(0) <= q_global_indices.unsqueeze(1)  # [valid_q_bsa, historical_k_bsa_blocks]

    # Count positions inside the causal region only
    total_historical_causal = causal_mask_historical.sum().item() * B * H
    selected_historical = (mask_historical & causal_mask_historical.unsqueeze(0).unsqueeze(0)).sum().item()

    if total_historical_causal > 0:
        DensityObserver.record_counts(layer_id, selected_historical, total_historical_causal)
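The greedy threshold selection used above can be sketched in plain Python (the scores and threshold here are made-up values; the real code sorts a torch tensor of per-CPU-block score ratios):

```python
# Greedy threshold selection: sort scores descending, keep adding blocks
# until their cumulative share of the total reaches the threshold.
scores = [0.05, 0.40, 0.10, 0.30, 0.15]   # example per-block score ratios (sum to 1.0)
threshold = 0.8

order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
cumsum, selected = 0.0, set()
for idx in order:
    selected.add(idx)
    cumsum += scores[idx]
    if cumsum >= threshold:
        break

selected_ids = sorted(selected)
print(selected_ids)   # [1, 3, 4]  (0.40 + 0.30 + 0.15 = 0.85 >= 0.8)
```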
# 7b: Record current-chunk density (causal, to align with GPU-only mode).
# The current chunk is the portion after the historical blocks.
if valid_curr_k_bsa > 0:
    # Extract the current-chunk mask (valid portion only, not padded)
    mask_current = mask[:, :, :valid_q_bsa, historical_k_bsa_blocks:historical_k_bsa_blocks + valid_curr_k_bsa]

    q_dim = mask_current.shape[2]
    k_dim = mask_current.shape[3]

    # Create the causal mask (lower triangular):
    # within the current chunk, Q[i] can see K[j] where j <= i (standard causal)
    causal_mask = torch.tril(torch.ones(q_dim, k_dim, device=mask.device, dtype=torch.bool))

    # Count positions inside the causal region only
    total_current_causal = causal_mask.sum().item() * B * H
    selected_current = (mask_current & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()

    if total_current_causal > 0:
        DensityObserver.record_counts(layer_id, selected_current, total_current_causal)
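The causal density count can be illustrated without torch. Only lower-triangular (j <= i) block positions count toward the denominator, matching the `torch.tril` logic; the mask below is a hypothetical 4x4 selection:

```python
# Block-level density under a causal mask: only positions with j <= i count.
q_blocks, k_blocks = 4, 4
# Hypothetical selection mask: mask[i][j] == True means block (i, j) is kept
mask = [[True,  False, False, False],
        [True,  True,  False, False],
        [False, True,  True,  False],
        [True,  False, False, True ]]

total = sum(1 for i in range(q_blocks) for j in range(k_blocks) if j <= i)
selected = sum(1 for i in range(q_blocks) for j in range(k_blocks)
               if j <= i and mask[i][j])
density = selected / total
print(selected, total, density)   # 7 10 0.7
```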
# Step 7.5: Save the historical mask to the pre-allocated buffer for compute_chunked_prefill.
# Use the full (padded) Q_bsa for the buffer, not valid_q_bsa.
mask_historical_full = mask[:, :, :, :historical_k_bsa_blocks]
if self._prefill_mask_buffer is not None:
    # Only the historical portion of the mask is saved
    self._prefill_mask_buffer[:, :, :Q_bsa, :historical_k_bsa_blocks].copy_(mask_historical_full)
    self._current_mask_q_bsa = Q_bsa
    self._current_mask_k_bsa = historical_k_bsa_blocks
# ============================================================
# Step 8: Aggregate the mask to CPU-block level (union over heads)
# ============================================================
# Only historical blocks are aggregated (the current chunk always uses full attention)
num_cpu_blocks = num_historical_blocks

with nvtx.range("xattn_aggregate_mask"):
    # Reshape historical mask: [B, H, Q_bsa, historical_k_bsa] -> [B, H, Q_bsa, num_cpu, bsa_per_cpu]
    # Use the full Q_bsa (not valid_q_bsa) for aggregation
    mask_per_cpu = mask_historical_full.view(B, H, Q_bsa, num_cpu_blocks, bsa_per_cpu)

    # Union across bsa_per_cpu, Q_bsa, and heads -> [B, num_cpu]
    cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1)  # [B, num_cpu]

# Get the selected indices
selected_indices = cpu_needed[0].nonzero().squeeze(-1).tolist()
if isinstance(selected_indices, int):
    selected_indices = [selected_indices]
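The union aggregation above reduces to a simple rule: a CPU block is needed if any of its BSA sub-blocks is selected by any head. A dependency-free sketch with made-up sizes (the real code does this with chained `.any()` reductions on a 5-D tensor):

```python
# 8 BSA blocks grouped into 2 CPU blocks, 4 BSA blocks per CPU block.
bsa_per_cpu = 4
# Hypothetical per-head BSA selections
head_masks = [
    [False, False, True,  False, False, False, False, False],  # head 0
    [False, False, False, False, False, False, False, False],  # head 1
]

num_cpu = len(head_masks[0]) // bsa_per_cpu
cpu_needed = [
    any(head[c * bsa_per_cpu + b] for head in head_masks for b in range(bsa_per_cpu))
    for c in range(num_cpu)
]
print(cpu_needed)   # [True, False]
```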
# Handle the empty available_blocks case (first chunk)
if available_blocks:
    selected_block_ids = [available_blocks[i] for i in selected_indices]
else:
    selected_block_ids = []

# Always include the first block (sink) and the last block for safety
if available_blocks and available_blocks[0] not in selected_block_ids:

@@ -598,6 +758,14 @@ class XAttentionBSAPolicy(SparsePolicy):

if available_blocks and available_blocks[-1] not in selected_block_ids:
    selected_block_ids.append(available_blocks[-1])
# Record communication density (CPU-block granularity), only when historical blocks exist
if available_blocks:
    DensityObserver.record_comm_density(
        layer_id,
        selected_cpu_blocks=len(selected_block_ids),
        total_cpu_blocks=len(available_blocks),
    )

# Update statistics (layer 0 only, to avoid overcounting)
if layer_id == 0 and available_blocks:
    self._stats_total_available_blocks += len(available_blocks)

@@ -610,7 +778,7 @@ class XAttentionBSAPolicy(SparsePolicy):

            f"selected={len(selected_block_ids)}, chunk_density={chunk_density:.1%}")

# Free intermediate tensors to prevent a memory leak
del attn_scores, block_sums_fine, block_sums_coarse, cpu_block_scores, scores_per_block
del attn_scores, block_sums, mask, mask_historical_full

return selected_block_ids
@@ -636,6 +804,10 @@ class XAttentionBSAPolicy(SparsePolicy):

    2. Compute attention to the current chunk
    3. Merge all results

    Note: The BSA-level mask is saved in self._prefill_mask_buffer by select_blocks().
    We currently use flash_attn_with_lse for computation (it supports LSE merge).
    TODO: Optimize to use the BSA kernel with the saved mask for per-head sparse attention.

    Args:
        q: Query tensor [seq_len, num_heads, head_dim]
        k: Key tensor [seq_len, num_kv_heads, head_dim] (unused, read from the prefill buffer)

@@ -666,6 +838,11 @@ class XAttentionBSAPolicy(SparsePolicy):

    # Use the pre-selected blocks directly
    cpu_block_table = selected_blocks

    # Note: the BSA mask is available in self._prefill_mask_buffer (saved by select_blocks).
    # Mask shape: [1, num_heads, Q_bsa, K_bsa] where Q_bsa = self._current_mask_q_bsa
    # Selected indices: self._selected_cpu_indices; bsa_per_cpu: self._bsa_per_cpu
    # TODO: use this mask with the BSA kernel for per-head sparse attention

    if cpu_block_table:
        with nvtx.range("xattn_compute_historical"):
            load_slots = list(range(offload_engine.num_ring_slots))
@@ -221,20 +221,19 @@ class Attention(nn.Module):

    cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

    # Step 2: Apply select_blocks to filter blocks (before calling compute_chunked_prefill)
    selected_blocks = []
    if cpu_block_table:
        num_chunks = current_chunk_idx + 1
        policy_ctx = PolicyContext(
            query_chunk_idx=current_chunk_idx,
            num_query_chunks=num_chunks,
            layer_id=self.layer_id,
            query=q,  # Pass the query for sparse policies that need it
            is_prefill=True,
            block_size=kvcache_manager.block_size,
            total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
        )
        selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
        logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")

    # Always call select_blocks, even for the first chunk (cpu_block_table may be empty)
    num_chunks = current_chunk_idx + 1
    policy_ctx = PolicyContext(
        query_chunk_idx=current_chunk_idx,
        num_query_chunks=num_chunks,
        layer_id=self.layer_id,
        query=q,  # Pass the query for sparse policies that need it
        is_prefill=True,
        block_size=kvcache_manager.block_size,
        total_kv_len=len(cpu_block_table) * kvcache_manager.block_size if cpu_block_table else 0,
    )
    selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
    logger.debug(f"[DEBUG] select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")

    # [DEBUG] Verify execution path
    logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "

@@ -320,7 +319,7 @@ class Attention(nn.Module):

        block_size=kvcache_manager.block_size,
        total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
    )
    selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
    selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx, q, k)
    logger.debug(f"[DEBUG] decode select_blocks: {len(cpu_block_table)} -> {len(selected_blocks)} blocks")

    # [DEBUG] Verify execution path
nanovllm/utils/density_observer.py (new file, 308 lines)
@@ -0,0 +1,308 @@
"""
DensityObserver - an observer that collects sparse-attention density statistics.

Two kinds of density are tracked:
1. Compute density: based on the BSA block size (128)
   - density = selected_bsa_blocks / total_causal_bsa_blocks
   - GPU-only and offload modes should agree
2. Communication density: based on the CPU block size (e.g. 4096)
   - comm_density = selected_cpu_blocks / total_cpu_blocks
   - Offload mode only; because the granularity is coarser, it is always >= compute density

Where statistics are recorded:
- GPU-only: xattn_bsa.py compute_prefill() - compute density only
- Offload: xattn_bsa.py select_blocks() - both densities

Density calculation in offload mode:
- Not a plain avg or min,
- but sum(selected) / sum(total), which weights chunks of different sizes correctly.
"""
from typing import List, Dict, Optional, Tuple

import torch

from nanovllm.utils.observer import Observer
class DensityObserver(Observer):
    """
    Sparse-attention density observer.

    Records per-layer density, used to verify consistency between the
    GPU-only and offload modes.

    Usage:
        DensityObserver.enable()
        DensityObserver.complete_reset()
        # ... run inference ...
        DensityObserver.record(layer_id, mask, causal=True)
        # or, in accumulation mode (offload):
        DensityObserver.record_counts(layer_id, selected, total)
        # ...
        DensityObserver.print_summary()
    """

    _enabled: bool = False  # disabled by default
    # Per-layer compute-density records (BSA-block granularity)
    # key: layer_id, value: list of density values (one per prefill chunk)
    _layer_densities: Dict[int, List[float]] = {}

    # Per-layer communication-density records (CPU-block granularity, offload mode only)
    _layer_comm_densities: Dict[int, List[float]] = {}

    # Accumulation mode: record selected/total counts (for offload mode).
    # This lets us compute density = sum(selected) / sum(total) correctly
    # after all chunks have been processed.
    _layer_selected_counts: Dict[int, List[int]] = {}
    _layer_total_counts: Dict[int, List[int]] = {}

    # Last mask shape (for debugging)
    _last_q_blocks: int = 0
    _last_k_blocks: int = 0

    # Mode flag
    _mode: str = "unknown"  # "gpu_only" or "offload"

    @classmethod
    def set_mode(cls, mode: str) -> None:
        """Set the current mode (gpu_only / offload)."""
        cls._mode = mode
    @classmethod
    def record(
        cls,
        layer_id: int,
        mask: torch.Tensor,
        causal: bool = True,
    ) -> float:
        """
        Record one layer's density (GPU-only mode).

        Args:
            layer_id: layer ID
            mask: [batch, heads, q_blocks, k_blocks] boolean tensor
            causal: whether to honor the causal mask (count the lower triangle only)

        Returns:
            the density value
        """
        if not cls._enabled:
            return 0.0

        density = cls._compute_density(mask, causal)

        # Record the value
        if layer_id not in cls._layer_densities:
            cls._layer_densities[layer_id] = []
        cls._layer_densities[layer_id].append(density)

        # Record the mask shape
        cls._last_q_blocks = mask.shape[2]
        cls._last_k_blocks = mask.shape[3]

        return density
    @classmethod
    def record_counts(
        cls,
        layer_id: int,
        selected_blocks: int,
        total_blocks: int,
    ) -> None:
        """
        Record one layer's selected/total block counts (offload accumulation mode).

        Counts are accumulated instead of computing a density per call, so that
        after all chunks are processed the overall density can be computed as:
            overall_density = sum(selected) / sum(total)
        This is more accurate than avg(density), because the Q and K lengths
        differ between chunks.

        Args:
            layer_id: layer ID
            selected_blocks: number of blocks selected for this chunk
            total_blocks: total number of possible blocks for this chunk
        """
        if not cls._enabled:
            return

        # Initialize the lists
        if layer_id not in cls._layer_selected_counts:
            cls._layer_selected_counts[layer_id] = []
        if layer_id not in cls._layer_total_counts:
            cls._layer_total_counts[layer_id] = []

        # Accumulate
        cls._layer_selected_counts[layer_id].append(selected_blocks)
        cls._layer_total_counts[layer_id].append(total_blocks)
    @classmethod
    def record_comm_density(
        cls,
        layer_id: int,
        selected_cpu_blocks: int,
        total_cpu_blocks: int,
    ) -> float:
        """
        Record one layer's communication density (CPU-block granularity).

        Args:
            layer_id: layer ID
            selected_cpu_blocks: number of CPU blocks selected
            total_cpu_blocks: total number of CPU blocks

        Returns:
            the communication density value
        """
        if not cls._enabled:
            return 0.0

        if total_cpu_blocks == 0:
            return 1.0

        comm_density = selected_cpu_blocks / total_cpu_blocks

        # Record the value
        if layer_id not in cls._layer_comm_densities:
            cls._layer_comm_densities[layer_id] = []
        cls._layer_comm_densities[layer_id].append(comm_density)

        return comm_density
    @classmethod
    def _compute_density(cls, mask: torch.Tensor, causal: bool) -> float:
        """Compute the density of a mask."""
        batch, heads, q_blocks, k_blocks = mask.shape

        if causal:
            # Count the lower-triangular region only
            causal_mask = torch.tril(
                torch.ones(q_blocks, k_blocks, device=mask.device, dtype=torch.bool)
            )
            total_blocks = causal_mask.sum().item() * batch * heads
            selected_blocks = (mask & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
        else:
            total_blocks = mask.numel()
            selected_blocks = mask.sum().item()

        if total_blocks == 0:
            return 1.0

        return selected_blocks / total_blocks
    @classmethod
    def complete_reset(cls) -> None:
        """Reset all statistics."""
        cls._layer_densities = {}
        cls._layer_comm_densities = {}
        cls._layer_selected_counts = {}
        cls._layer_total_counts = {}
        cls._last_q_blocks = 0
        cls._last_k_blocks = 0
        cls._mode = "unknown"
    @classmethod
    def get_per_layer_density(cls) -> Dict[int, float]:
        """
        Get each layer's density.

        Accumulation mode (offload): density = sum(selected) / sum(total)
        Direct-record mode (gpu_only): density = avg(density_values)
        """
        result = {}

        # Prefer accumulation mode (offload)
        if cls._layer_selected_counts:
            for layer_id in cls._layer_selected_counts:
                selected_list = cls._layer_selected_counts.get(layer_id, [])
                total_list = cls._layer_total_counts.get(layer_id, [])
                total_selected = sum(selected_list)
                total_total = sum(total_list)
                if total_total > 0:
                    result[layer_id] = total_selected / total_total
        else:
            # Direct-record mode (gpu_only)
            for layer_id, densities in cls._layer_densities.items():
                if densities:
                    result[layer_id] = sum(densities) / len(densities)

        return result
    @classmethod
    def get_overall_density(cls) -> float:
        """
        Get the overall compute density across all layers.

        Accumulation mode (offload): density = sum(all_selected) / sum(all_total)
        Direct-record mode (gpu_only): density = avg(all_density_values)

        Note: the overall density is not simply avg(per_layer_density);
        sum(all_selected) / sum(all_total) weights the layers correctly.
        """
        # Prefer accumulation mode (offload)
        if cls._layer_selected_counts:
            total_selected = 0
            total_total = 0
            for layer_id in cls._layer_selected_counts:
                total_selected += sum(cls._layer_selected_counts[layer_id])
                total_total += sum(cls._layer_total_counts.get(layer_id, []))
            if total_total > 0:
                return total_selected / total_total
            return 0.0

        # Direct-record mode (gpu_only)
        all_densities = []
        for densities in cls._layer_densities.values():
            all_densities.extend(densities)
        if not all_densities:
            return 0.0
        return sum(all_densities) / len(all_densities)
    @classmethod
    def get_overall_comm_density(cls) -> float:
        """Get the average communication density across all layers."""
        all_densities = []
        for densities in cls._layer_comm_densities.values():
            all_densities.extend(densities)
        if not all_densities:
            return 0.0
        return sum(all_densities) / len(all_densities)
    @classmethod
    def get_summary(cls) -> dict:
        """Return a statistics summary."""
        per_layer = cls.get_per_layer_density()
        return {
            "mode": cls._mode,
            "overall_density": cls.get_overall_density(),
            "per_layer_density": per_layer,
            "num_layers": len(per_layer),
            "last_mask_shape": {
                "q_blocks": cls._last_q_blocks,
                "k_blocks": cls._last_k_blocks,
            },
        }
    @classmethod
    def get_min_density(cls) -> Tuple[int, float]:
        """Get the layer with the lowest density, and its value."""
        per_layer = cls.get_per_layer_density()
        if not per_layer:
            return -1, 0.0
        min_layer = min(per_layer, key=per_layer.get)
        return min_layer, per_layer[min_layer]
    @classmethod
    def print_summary(cls) -> None:
        """Print a human-readable summary."""
        per_layer = cls.get_per_layer_density()
        overall = cls.get_overall_density()
        min_layer, min_density = cls.get_min_density()
        overall_comm = cls.get_overall_comm_density()

        print(f"[DensityObserver] Mode: {cls._mode}")
        print(f"  Compute density: {overall:.4f} (min: {min_density:.4f} @ layer {min_layer})")
        if overall_comm > 0:
            print(f"  Comm density: {overall_comm:.4f}")
        print(f"  Num layers: {len(per_layer)}")
        # Print layer 0's density for comparison
        if 0 in per_layer:
            print(f"  Layer 0 density: {per_layer[0]:.6f}")
@@ -41,6 +41,7 @@ from pathlib import Path

from typing import List, Dict, Tuple, Optional

from nanovllm import LLM, SamplingParams
from nanovllm.utils.density_observer import DensityObserver


# ============================================================

@@ -381,6 +382,13 @@ def run_ruler_benchmark(

    print(f"Fresh LLM mode: {fresh_llm}")
    print(f"{'='*60}")

    # Enable the DensityObserver for XAttention BSA
    if sparse_policy and sparse_policy.upper() == "XATTN_BSA":
        DensityObserver.enable()
        DensityObserver.complete_reset()
        if not json_output:
            print("[DensityObserver] Enabled for XAttention BSA")

    # LLM initialization kwargs
    llm_kwargs = {
        "max_model_len": max_model_len,

@@ -471,6 +479,14 @@ def run_ruler_benchmark(

    print(f"{'-'*54}")
    print(f"{'TOTAL':<20} {total_correct}/{total_samples:<7} {overall_accuracy*100:>6.1f}% {avg_score:.3f}")
    print(f"\nTime: {total_time:.1f}s")

    # Print the DensityObserver summary if enabled
    if sparse_policy and sparse_policy.upper() == "XATTN_BSA" and DensityObserver.is_enabled():
        print(f"\n{'='*60}")
        print("Density Statistics (XAttention BSA)")
        print(f"{'='*60}")
        DensityObserver.print_summary()

    print(f"{'='*60}\n")

    results = {
@@ -41,9 +41,9 @@ K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()

for i in range(q_len):
    if i % 2 == 0:
        Q[0, 0, i, :] = 1
        Q[0, 0, i, :] = 1 * (i // stride + 1)
    else:
        Q[0, 0, i, :] = 2
        Q[0, 0, i, :] = 2 * (i // stride + 1)

for i in range(kv_len):
    if i % 2 == 0:

@@ -74,8 +74,11 @@ for k_chunk_idx in range(num_k_chunks):

        Q, K_chunk, stride,
        chunk_start=0,
        chunk_end=q_reshaped_len,
        is_causal=False
        is_causal=True
    )

    __import__('pdb').set_trace()

    attn_scores_list.append(attn_chunk)

# Concatenate the results across all K chunks