Compare commits

5 commits: `a1c68a733e` ... `1eb7521994`

| SHA1 |
|---|
| 1eb7521994 |
| 51bd678335 |
| 1ea5afd886 |
| 829b311c02 |
| dd0472aea8 |
.ralph-tui/config.toml (new file, 12 lines)

@@ -0,0 +1,12 @@
# Ralph TUI Configuration
# Generated by setup wizard
# See: ralph-tui config help

configVersion = "2.1"
tracker = "json"
agent = "claude"
maxIterations = 30
autoCommit = false

[trackerOptions]
[agentOptions]
@@ -42,6 +42,8 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/xattn_kv_chunking_density_test.md`](docs/xattn_kv_chunking_density_test.md) | 🧪 TEST: XAttention KV chunking density validation; exact match at threshold=1.0, 10-13% difference at threshold<1.0 |
 | [`docs/gpuonly_density_alignment_test.md`](docs/gpuonly_density_alignment_test.md) | ✅ TEST: Density alignment validation (GPU-only + Offload, 4K-64K); xattn_estimate and KV chunking fully consistent |
 | [`docs/xattn_memory_benchmark.md`](docs/xattn_memory_benchmark.md) | 📊 BENCH: XAttention memory benchmark; Qwen3-0.6B at 32K fits in 24GB VRAM (gpu-util=0.28) |
+| [`docs/xattn_offload_stream_sync_fix.md`](docs/xattn_offload_stream_sync_fix.md) | 🐛 FIX: XAttention offload stream-sync bug; Pass 1 / Pass 2 K data mismatch, fixed by wrapping compute in compute_stream |
+| [`docs/xattn_density_types.md`](docs/xattn_density_types.md) | 📊 Compute vs comm density: BSA block (128) vs CPU block (4096) granularity; aggregation drives comm density to 100% |
 
 ## Rules Index
 

@@ -106,6 +108,13 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
 
 **Files**: `bench.py` (GPU), `bench_offload.py` (CPU offload), `bench_vllm.py` (comparison)
 
+**GPU-only test model selection**:
+
+| GPU | VRAM | GPU-only test model |
+|-----|------|---------------------|
+| RTX 3090 | 24GB | **Qwen3-0.6B** (required; 7B+ models will OOM) |
+| A100 | 40GB+ | Qwen3-0.6B / 4B / 7B all work |
+
 **Offload Mode Constraint**: When using `enable_cpu_offload=True`, only test with context length ≥ 32K. Shorter contexts don't exercise the chunked offload pipeline properly.
 
 **Common Issues**:
docs/xattn_density_types.md (new file, 152 lines)

@@ -0,0 +1,152 @@
# XAttention Density Types: Compute vs Communication

XAttention BSA reports two densities at different granularities; they measure different optimization effects.

## The Two Density Definitions

### 1. Compute Density

**Granularity**: BSA block (128 tokens)

**Formula**:
```
compute_density = selected_bsa_blocks / total_causal_bsa_blocks
```

**Meaning**: the fraction of blocks in the causal region on which attention is actually computed.

**Impact**: determines the reduction in attention compute.

### 2. Communication Density

**Granularity**: CPU block (4096 tokens = 32 BSA blocks)

**Formula**:
```
comm_density = selected_cpu_blocks / total_cpu_blocks
```

**Meaning**: the fraction of blocks that must be transferred from CPU to GPU.

**Impact**: determines the reduction in H2D transfer volume.

## Why Comm Density Is Usually Higher Than Compute Density

### Aggregation Effect

Because a CPU block is 32× the granularity of a BSA block, CPU-block selection aggregates with `any()`:

```python
# BSA mask: [B, H, Q_bsa, K_bsa]
# Reshape to CPU block level
mask_per_cpu = mask.view(B, H, Q_bsa, num_cpu_blocks, bsa_per_cpu)
# Any BSA block selected -> whole CPU block needed
cpu_needed = mask_per_cpu.any(dim=-1).any(dim=2).any(dim=1)
```

A CPU block must be transferred in full as soon as **any one** of the following holds:

- some head selects it, or
- some Q position selects it, or
- some BSA sub-block inside it is selected.
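The aggregation effect can be sketched without the tensor machinery. A minimal pure-Python illustration (the block counts and the every-8th selection pattern are hypothetical, not the real mask layout):

```python
# Sketch: how any()-style aggregation inflates comm density.
# Assumes 4 CPU blocks of 32 BSA blocks each (hypothetical sizes).
BSA_PER_CPU = 32
NUM_CPU_BLOCKS = 4
TOTAL_BSA = BSA_PER_CPU * NUM_CPU_BLOCKS

# Sparse selection: every 8th BSA block is selected (12.5% compute density),
# spread evenly across the sequence.
selected_bsa = [i % 8 == 0 for i in range(TOTAL_BSA)]

compute_density = sum(selected_bsa) / TOTAL_BSA

# A CPU block is transferred if ANY of its BSA blocks is selected.
cpu_needed = [
    any(selected_bsa[c * BSA_PER_CPU:(c + 1) * BSA_PER_CPU])
    for c in range(NUM_CPU_BLOCKS)
]
comm_density = sum(cpu_needed) / NUM_CPU_BLOCKS

print(compute_density)  # 0.125
print(comm_density)     # 1.0 -- every CPU block contains a selected BSA block
```

An evenly spread 12.5% selection still touches every CPU block, which is exactly the 100% comm density seen in the tables below.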
### Example

| Scenario | Compute Density | Comm Density | Notes |
|------|-----------------|--------------|------|
| 64K context, threshold=0.9 | 37% | 100% | sparse blocks spread evenly across all CPU blocks |
| 32K context, threshold=0.9 | 50% | 100% | same as above |
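The threshold semantics behind these densities (keep the highest-scoring blocks until their cumulative attention mass reaches the threshold) can be sketched as follows; the helper name and scores are illustrative, not the real `find_blocks_chunked`:

```python
# Sketch: threshold-based block selection (illustrative, not the real
# find_blocks_chunked). Keep top-scoring blocks until their cumulative
# attention mass reaches the threshold.
def select_blocks_by_mass(block_scores, threshold):
    total = sum(block_scores)
    # Sort block indices by score, highest first.
    order = sorted(range(len(block_scores)), key=lambda i: -block_scores[i])
    selected, mass = [], 0.0
    for i in order:
        selected.append(i)
        mass += block_scores[i]
        if mass / total >= threshold:
            break
    return sorted(selected)

# A few blocks dominate the attention mass (hypothetical scores).
scores = [0.50, 0.02, 0.30, 0.02, 0.10, 0.02, 0.02, 0.02]
picked = select_blocks_by_mass(scores, threshold=0.9)
print(picked)                     # [0, 2, 4]
print(len(picked) / len(scores))  # 0.375 -- the resulting compute density
```

With skewed scores, a 0.9 threshold keeps only a minority of blocks, which is how compute density lands well below 100%.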
## Test Results

### Test Command

```bash
# Offload-mode test
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_64k \
    --datasets niah_single_1 \
    --num-samples 1 \
    --max-model-len 72000 \
    --enable-offload \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```

### Sample Output

```
[DensityObserver] Mode: offload
Compute density: 0.3691 (min: 0.3691 @ layer 0)
Comm density: 1.0000 (CPU block granularity)
Savings ratio: 0.0% H2D transfer reduction
Num layers: 1
Layer 0 density: 0.369052
```
## Key Findings

### Limits of XAttention's Communication Optimization Today

1. **Compute density drops effectively**: ~37% @ 64K context (63% less attention compute)
2. **Comm density does not drop**: 100% (no reduction in transfer volume)

### Why

Characteristics of the attention pattern:

- different heads attend to different positions
- different Q positions attend to different K positions
- the sparse selection is spread across the whole sequence

So although each (head, Q, K) combination selects only a few blocks, the aggregated selection covers every CPU block.

### Potential Optimizations

1. **Per-head block selection**: each head selects its CPU blocks independently
2. **Block clustering**: co-locate related blocks in the same CPU block
3. **Dynamic block size**: adapt the CPU block size to the attention pattern
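To see why per-head selection could help, here is a small sketch comparing any()-aggregated transfers against per-head transfers; the head count, block count, and one-block-per-head pattern are hypothetical, and transfers are counted in (head, block) slices as a simplification:

```python
# Sketch: per-head CPU block selection vs any()-aggregated selection.
# 4 heads, 4 CPU blocks; each head needs exactly one block (hypothetical).
NUM_HEADS = 4
NUM_CPU_BLOCKS = 4

# per_head_needed[h] = set of CPU blocks head h selected
per_head_needed = [{h} for h in range(NUM_HEADS)]  # head h wants block h

# Today: one aggregated selection shared by all heads -> union of all sets,
# and each transferred block carries every head's KV slice.
aggregated = set().union(*per_head_needed)
aggregated_transfers = NUM_HEADS * len(aggregated)

# Per-head: each head only receives its own blocks' slices.
per_head_transfers = sum(len(s) for s in per_head_needed)

print(aggregated_transfers)  # 16 head-block slices moved
print(per_head_transfers)    # 4 head-block slices moved
```

When head selections are disjoint, the aggregated scheme moves every head's slice of every block, so per-head selection is the direction with the most headroom.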
## DensityObserver API

### Enable and Reset

```python
from nanovllm.utils.density_observer import DensityObserver

DensityObserver.enable()
DensityObserver.complete_reset()
DensityObserver.set_mode("offload")  # or "gpu_only"
```

### Recording

```python
# Compute density (recorded automatically in GPU-only mode)
DensityObserver.record(layer_id, mask, causal=True)

# Comm density (recorded in select_blocks in offload mode)
DensityObserver.record_comm_density(layer_id, selected_cpu_blocks, total_cpu_blocks)
```

### Reading Results

```python
# Overall density
overall_compute = DensityObserver.get_overall_density()
overall_comm = DensityObserver.get_overall_comm_density()

# Per-layer density
per_layer_compute = DensityObserver.get_per_layer_density()
per_layer_comm = DensityObserver.get_per_layer_comm_density()

# Print a summary
DensityObserver.print_summary()
```

## Related Files

- `nanovllm/utils/density_observer.py`: DensityObserver implementation
- `nanovllm/kvcache/sparse/xattn_bsa.py`: XAttention BSA policy (records comm density in select_blocks)
- `tests/test_ruler.py`: RULER benchmark test script
docs/xattn_offload_stream_sync_fix.md (new file, 307 lines)

@@ -0,0 +1,307 @@
# XAttention Offload Stream Synchronization Fix

Fixes a CUDA stream synchronization bug in the XAttention BSA policy under offload mode.

**Fix date**: 2026-02-05
**Commit**: `829b311`
**Files affected**: `nanovllm/kvcache/sparse/xattn_bsa.py`, `nanovllm/kvcache/offload_engine.py`

---

## Problem

### Symptom

When running the RULER benchmark in offload mode, Pass 1 and Pass 2 of XAttention BSA's `select_blocks` loaded inconsistent K data from the **same CPU block**:

```
Pass 1: K_chunk sum = 745472.00 (correct)
Pass 2: K_chunk sum = 0.00 (wrong; data not yet loaded)
```

This corrupted the attention computation and lowered RULER accuracy.

### Reproduction Conditions

- Mode: offload (`--enable-offload`)
- Context: ≥ 32K tokens
- Sparse policy: `--sparse-policy XATTN_BSA`
## Root Cause

### Stream Configuration Recap

nano-vllm's CPU offload pipelines work across several CUDA streams:

| Stream | Purpose |
|--------|------|
| `slot_transfer_streams[i]` | H2D transfer (CPU → GPU slot) |
| `compute_stream` | attention computation |
| `prefill_offload_streams[i]` | D2H transfer (GPU → CPU cache) |

### Synchronization Mechanism

`wait_slot_layer(slot)` synchronizes via an event:

```python
def wait_slot_layer(self, slot_idx: int):
    """Make compute_stream wait for H2D transfer completion."""
    self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])
```

### The Bug

Inside `select_blocks`:

1. The H2D transfer runs on `slot_transfer_streams`.
2. `wait_slot_layer` makes `compute_stream` wait for the transfer.
3. **But** the subsequent compute kernels run on the **default stream**, not `compute_stream`.

```python
# Buggy code
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)  # compute_stream waits

# These kernels run on the default stream and never wait for the H2D transfer!
k_block = offload_engine.get_k_for_slot(slot)
K_chunk = k_block.transpose(1, 2)
# ... subsequent compute ...
```

### Timeline

```
slot_transfer_stream: [====H2D====]
compute_stream:                    |wait|
default_stream:       [kernel1][kernel2]   ← never waits!
                          ↑
                      data not ready
```
---

## Fix

### Core Change

Wrap all estimate-phase compute kernels in `with torch.cuda.stream(compute_stream):`:

```python
# Fixed code
compute_stream = offload_engine.compute_stream

offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)  # compute_stream waits

# All compute runs on compute_stream
with torch.cuda.stream(compute_stream):
    k_block = offload_engine.get_k_for_slot(slot)
    K_chunk = k_block.transpose(1, 2)
    # ... subsequent compute ...
```

### Fix Locations

Six places in `select_blocks` needed the change:

| Location | Phase | What was wrapped |
|------|------|----------|
| Pass 1, historical blocks | `xattn_estimate_pass1` | historical KV chunk processing |
| Pass 1, current chunk | `xattn_estimate_pass1` | K already on the GPU |
| Step 2 merge | `merge_softmax_stats` | softmax stats merge |
| Pass 2, historical blocks | `xattn_estimate_pass2` | block_sum with global stats |
| Pass 2, current chunk | `xattn_estimate_pass2` | current chunk's block_sum |
| Step 4 block selection | `find_blocks_chunked` | final block selection |

### Timeline (After the Fix)

```
slot_transfer_stream: [====H2D====]
compute_stream:                    |wait|[kernel1][kernel2]
                                       ↑
                                   data ready
```
---

## Change Details

### 1. Pass 1 Historical Block Processing

```python
# Before (bug)
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
    offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
    offload_engine.wait_slot_layer(slot)

    k_block = offload_engine.get_k_for_slot(slot)  # default stream
    K_chunk = k_block.transpose(1, 2)
    # ... compute ...

# After (fixed)
compute_stream = offload_engine.compute_stream

for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
    offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id)
    offload_engine.wait_slot_layer(slot)

    with torch.cuda.stream(compute_stream):  # explicit stream
        k_block = offload_engine.get_k_for_slot(slot)
        K_chunk = k_block.transpose(1, 2)
        # ... compute ...
```

### 2. Removal of the STRONG SYNC

The unnecessary hard synchronization in `offload_engine.py` was removed:

```python
# Removed from load_to_slot_layer() and load_k_only_to_slot_layer()
# STRONG SYNC: Synchronize all prefill offload streams before H2D
# for offload_stream in self.prefill_offload_streams:
#     offload_stream.synchronize()
```

These orderings are now handled correctly by the event mechanism; blocking synchronization is no longer needed.

### 3. Other Cleanups

- Removed DEBUG print statements
- Removed `torch.save()` debug code
- Consolidated multiple fallback conditions
- Changed the `chunk_size` default from 16384 to 4096 (matches the offload Q chunk size)
---

## Test Verification

### Test Commands

**GPU 0 — offload mode**:

```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 10 \
    --max-model-len 40960 \
    --enable-offload \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```

**GPU 1 — GPU-only mode**:

```bash
CUDA_VISIBLE_DEVICES=1 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
    --model ~/models/Qwen3-0.6B \
    --data-dir tests/data/ruler_32k \
    --datasets niah_single_1 \
    --num-samples 10 \
    --max-model-len 40960 \
    --sparse-policy XATTN_BSA \
    --sparse-threshold 0.9
```

### Results

| Mode | Model | Context | Samples | Pass Rate | Density |
|------|------|---------|---------|-----------|---------|
| Offload | Llama-3.1-8B | 32K | 10/10 | **100%** | 9.53% |
| GPU-only | Qwen3-0.6B | 32K | 10/10 | **100%** | 9.84% |

### Density Alignment Check

| Mode | Layer 0 Density | Difference |
|------|-----------------|------|
| GPU-only | 9.84% | - |
| Offload | 9.53% | ~3% |

The ~3% difference is expected, because the two modes accumulate KV differently:

- GPU-only: processes all KV at once
- Offload: processes KV in chunks, computing softmax stats per chunk and merging them
---

## Technical Details

### Three-Stage KV Chunking Flow

```
┌─────────────────────────────────────────────────────────────┐
│ Stage 1: softmax_compute_partial_stats                      │
│   └── each KV chunk computes partial stats (m_i, l_i)       │
│                                                             │
│ Stage 2: merge_softmax_stats                                │
│   └── host merges all chunks: (m_global, l_global)          │
│                                                             │
│ Stage 3: softmax_normalize_and_block_sum                    │
│   └── normalize with global stats and compute block sums    │
└─────────────────────────────────────────────────────────────┘
```
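The merge in Stage 2 follows the standard online-softmax recurrence; a minimal numeric sketch in plain Python (not the real `merge_softmax_stats` kernel, which operates on tensors per block):

```python
import math

# Sketch of the online-softmax merge behind merge_softmax_stats.
# Each chunk contributes m_i = max(scores_i) and
# l_i = sum(exp(scores_i - m_i)).
def partial_stats(scores):
    m = max(scores)
    l = sum(math.exp(s - m) for s in scores)
    return m, l

def merge_stats(stats):
    m_global = max(m for m, _ in stats)
    # Rescale each partial sum into the global max's frame.
    l_global = sum(l * math.exp(m - m_global) for m, l in stats)
    return m_global, l_global

chunk1, chunk2 = [1.0, 2.0, 3.0], [0.5, 4.0]
m_g, l_g = merge_stats([partial_stats(chunk1), partial_stats(chunk2)])

# The merged stats equal those computed over the full sequence at once.
m_ref, l_ref = partial_stats(chunk1 + chunk2)
print(abs(m_g - m_ref) < 1e-12 and abs(l_g - l_ref) < 1e-12)  # True
```

Because the rescaling is exact, chunked processing reproduces the global softmax denominator, which is what makes the three-stage flow numerically equivalent to the one-shot GPU-only path.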
### Stream Assignment Requirements

| Operation | Stream | Reason |
|----------|--------|------|
| H2D transfer | `slot_transfer_streams` | async transfer; doesn't block compute |
| D2H transfer | `prefill_offload_streams` | async offload; doesn't block compute |
| Estimate kernels | `compute_stream` | shared with attention compute, guaranteeing ordering |
| Attention kernels | `compute_stream` | main compute stream |

### Event Synchronization

```python
# Record an event when the H2D transfer completes
self.ring_slot_ready[slot_idx].record(slot_transfer_stream)

# Wait for the H2D transfer before computing
self.compute_stream.wait_event(self.ring_slot_ready[slot_idx])

# Record an event when compute finishes (for the next H2D round)
self.ring_slot_compute_done[slot_idx].record(compute_stream)
```

---

## Related Docs

- [`docs/architecture_guide.md`](architecture_guide.md): stream configuration and ring buffer architecture
- [`docs/xattn_kv_chunking_kernels.md`](xattn_kv_chunking_kernels.md): the three-stage softmax kernels
- [`docs/gpuonly_density_alignment_test.md`](gpuonly_density_alignment_test.md): density alignment test
- [`docs/xattn_bsa_policy_design.md`](xattn_bsa_policy_design.md): XAttention BSA policy design

---

## Lessons Learned

### 1. Stream Sync Bugs Hide Well

CUDA stream synchronization bugs are hard to catch:

- the data can be correct "most of the time" (it depends on timing)
- errors show up as random or intermittent result drift
- precise debug logging is needed to pin them down

### 2. Event vs Synchronize

| Method | Pros | Cons |
|------|------|------|
| `stream.wait_event()` | non-blocking; keeps the pipeline going | only orders the specified stream |
| `stream.synchronize()` | guarantees completion | blocks the whole stream and breaks the pipeline |

**Best practice**: use events for precise ordering; avoid blocking `synchronize()` calls.

### 3. Debugging Techniques

```python
# Print tensor sums to check data consistency
print(f"K_chunk sum = {K_chunk.sum().item()}")

# Save intermediate results for offline comparison
torch.save({'K': K_chunk, 'layer': layer_id}, f'/tmp/debug_{pass_idx}_{chunk}.pt')
```
@@ -96,7 +96,7 @@ class XAttentionBSAPolicy(SparsePolicy):
         self,
         threshold: float = 0.95,  # High threshold for accuracy testing
         stride: int = 8,
-        chunk_size: int = 16384,
+        chunk_size: int = 4096,  # Match offload Q chunk size for density alignment
         block_size: int = 128,
         samples_per_chunk: int = 128,
         use_triton: bool = True,
@@ -289,9 +289,11 @@ class XAttentionBSAPolicy(SparsePolicy):
         Returns:
             Attention output [total_q, num_heads, head_dim]
         """
-        # When block_tables is provided (paged KV cache / prefix cache),
-        # fallback to flash_attn as XAttention expects contiguous K, V
-        if block_tables is not None:
+        # Fallback to flash attention when:
+        # 1. block_tables provided (paged KV cache / prefix cache) - XAttention expects contiguous K, V
+        # 2. BSA kernel not available
+        # 3. xattn_estimate not available
+        if block_tables is not None or not BSA_AVAILABLE or not XATTN_AVAILABLE:
             from flash_attn import flash_attn_varlen_func
             return flash_attn_varlen_func(
                 q, k, v,

@@ -304,32 +306,6 @@ class XAttentionBSAPolicy(SparsePolicy):
                 block_table=block_tables,
             )
 
-        if not BSA_AVAILABLE:
-            # Fallback to flash attention if BSA not available
-            from flash_attn import flash_attn_varlen_func
-            return flash_attn_varlen_func(
-                q, k, v,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_q,
-                max_seqlen_k=max_seqlen_k,
-                softmax_scale=softmax_scale,
-                causal=True,
-            )
-
-        if not XATTN_AVAILABLE:
-            # Fallback to flash attention if xattn not available
-            from flash_attn import flash_attn_varlen_func
-            return flash_attn_varlen_func(
-                q, k, v,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_q,
-                max_seqlen_k=max_seqlen_k,
-                softmax_scale=softmax_scale,
-                causal=True,
-            )
-
         from nanovllm.ops.xattn import xattn_estimate
 
         # Set DensityObserver mode on first layer
@@ -477,8 +453,7 @@ class XAttentionBSAPolicy(SparsePolicy):
         causal_total = q_bk * (q_bk + 1) // 2 * mask_trimmed.shape[0] * mask_trimmed.shape[1]
         causal_mask = torch.tril(torch.ones(q_bk, k_bk, device=mask_trimmed.device, dtype=torch.bool))
         selected = (mask_trimmed & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
-        logger.info(f"[DEBUG GPU-only Layer0] mask_shape={mask_trimmed.shape}, "
-                    f"density={selected/causal_total:.6f}, selected={selected}, total={causal_total}")
 
         DensityObserver.record(layer_id, mask_trimmed, causal=True)
 
         return output
@@ -633,98 +608,108 @@ class XAttentionBSAPolicy(SparsePolicy):
         l_chunks = []
         num_kv_chunks = num_historical_blocks + 1  # +1 for current chunk
 
+        # Get compute_stream for all compute kernels (like attention computation)
+        compute_stream = offload_engine.compute_stream
+
         with nvtx.range("xattn_estimate_pass1"):
             slot = 0
 
             # Process historical blocks (from CPU)
             for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
-                # Load K from CPU
+                # Load K from CPU (on slot_transfer_stream)
                 offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
+                # wait_slot_layer makes compute_stream wait for H2D transfer
                 offload_engine.wait_slot_layer(slot)
 
-                k_block = offload_engine.get_k_for_slot(slot)  # [1, block_size, num_kv_heads, head_dim]
-                K_chunk = k_block.transpose(1, 2)  # [1, num_kv_heads, block_size, head_dim]
+                # All compute kernels run on compute_stream (like attention computation)
+                with torch.cuda.stream(compute_stream):
+                    k_block = offload_engine.get_k_for_slot(slot)  # [1, block_size, num_kv_heads, head_dim]
+                    K_chunk = k_block.transpose(1, 2)  # [1, num_kv_heads, block_size, head_dim]
 
-                # GQA expansion
-                num_kv_heads = K_chunk.shape[1]
-                if num_heads != num_kv_heads:
-                    num_groups = num_heads // num_kv_heads
-                    K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
+                    # GQA expansion
+                    num_kv_heads = K_chunk.shape[1]
+                    if num_heads != num_kv_heads:
+                        num_groups = num_heads // num_kv_heads
+                        K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
 
-                # KV offset in reshaped space
-                kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
+                    # KV offset in reshaped space
+                    kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
 
-                # Compute raw attention scores
-                attn_weights_kv = flat_group_gemm_fuse_reshape(
-                    Q, K_chunk, self.stride,
-                    chunk_start=chunk_start,
-                    chunk_end=chunk_end,
-                    is_causal=False,  # K is incomplete; causal cannot be applied here
-                )
+                    # Compute raw attention scores
+                    attn_weights_kv = flat_group_gemm_fuse_reshape(
+                        Q, K_chunk, self.stride,
+                        chunk_start=chunk_start,
+                        chunk_end=chunk_end,
+                        is_causal=False,  # K is incomplete; causal cannot be applied here
+                    )
 
-                # Compute partial stats (with causal mask)
-                m_partial, l_partial = softmax_compute_partial_stats(
-                    attn_weights_kv,
-                    reshaped_block_size,
-                    segment_size,
-                    scale,
-                    chunk_start=chunk_start,
-                    kv_offset=kv_offset_reshaped,
-                    is_causal=True,
-                )
-                m_chunks.append(m_partial)
-                l_chunks.append(l_partial)
+                    # Compute partial stats (with causal mask)
+                    m_partial, l_partial = softmax_compute_partial_stats(
+                        attn_weights_kv,
+                        reshaped_block_size,
+                        segment_size,
+                        scale,
+                        chunk_start=chunk_start,
+                        kv_offset=kv_offset_reshaped,
+                        is_causal=True,
+                    )
+                    m_chunks.append(m_partial)
+                    l_chunks.append(l_partial)
 
-                offload_engine.record_slot_compute_done(slot)
-                del attn_weights_kv
+                    offload_engine.record_slot_compute_done(slot)
+                    del attn_weights_kv
 
-            # Process current chunk K (already on GPU)
-            # k: [seq_len, num_kv_heads, head_dim] -> [1, num_kv_heads, seq_len, head_dim]
-            K_current = k.unsqueeze(0).transpose(1, 2)
+            # Process current chunk K (already on GPU) on compute_stream
+            with torch.cuda.stream(compute_stream):
+                # k: [seq_len, num_kv_heads, head_dim] -> [1, num_kv_heads, seq_len, head_dim]
+                K_current = k.unsqueeze(0).transpose(1, 2)
 
-            # GQA expansion for current chunk
-            num_kv_heads = K_current.shape[1]
-            if num_heads != num_kv_heads:
-                num_groups = num_heads // num_kv_heads
-                K_current = K_current.repeat_interleave(num_groups, dim=1)
+                # GQA expansion for current chunk
+                num_kv_heads = K_current.shape[1]
+                if num_heads != num_kv_heads:
+                    num_groups = num_heads // num_kv_heads
+                    K_current = K_current.repeat_interleave(num_groups, dim=1)
 
-            # Pad current K to alignment
-            curr_k_len = K_current.shape[2]
-            padded_curr_k_len = ((curr_k_len + alignment - 1) // alignment) * alignment
-            if padded_curr_k_len != curr_k_len:
-                K_current = torch.nn.functional.pad(K_current, (0, 0, 0, padded_curr_k_len - curr_k_len), value=0)
+                # Pad current K to alignment
+                curr_k_len = K_current.shape[2]
+                padded_curr_k_len = ((curr_k_len + alignment - 1) // alignment) * alignment
+                if padded_curr_k_len != curr_k_len:
+                    K_current = torch.nn.functional.pad(K_current, (0, 0, 0, padded_curr_k_len - curr_k_len), value=0)
 
-            # KV offset for current chunk
-            kv_offset_current = num_historical_blocks * kv_chunk_reshaped
+                # KV offset for current chunk
+                kv_offset_current = num_historical_blocks * kv_chunk_reshaped
 
-            # Compute attention scores for current chunk
-            attn_weights_curr = flat_group_gemm_fuse_reshape(
-                Q, K_current, self.stride,
-                chunk_start=chunk_start,
-                chunk_end=chunk_end,
-                is_causal=False,
-            )
+                # Compute attention scores for current chunk
+                attn_weights_curr = flat_group_gemm_fuse_reshape(
+                    Q, K_current, self.stride,
+                    chunk_start=chunk_start,
+                    chunk_end=chunk_end,
+                    is_causal=False,
+                )
 
-            # Compute partial stats for current chunk
-            m_partial_curr, l_partial_curr = softmax_compute_partial_stats(
-                attn_weights_curr,
-                reshaped_block_size,
-                segment_size,
-                scale,
-                chunk_start=chunk_start,
-                kv_offset=kv_offset_current,
-                is_causal=True,
-            )
-            m_chunks.append(m_partial_curr)
-            l_chunks.append(l_partial_curr)
-            del attn_weights_curr
+                # Compute partial stats for current chunk
+                m_partial_curr, l_partial_curr = softmax_compute_partial_stats(
+                    attn_weights_curr,
+                    reshaped_block_size,
+                    segment_size,
+                    scale,
+                    chunk_start=chunk_start,
+                    kv_offset=kv_offset_current,
+                    is_causal=True,
+                )
+                m_chunks.append(m_partial_curr)
+                l_chunks.append(l_partial_curr)
+                del attn_weights_curr
 
         # ================================================================
-        # Step 2: Merge all partial stats
+        # Step 2: Merge all partial stats (on compute_stream)
         # ================================================================
-        with nvtx.range("xattn_estimate_merge"):
-            m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
-        del m_chunks, l_chunks
+        with torch.cuda.stream(compute_stream):
+            with nvtx.range("xattn_estimate_merge"):
+                m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
 
+        del m_chunks, l_chunks
 
         # ================================================================
         # Step 3: Second pass - normalize and compute block sums
@@ -736,30 +721,61 @@ class XAttentionBSAPolicy(SparsePolicy):
 
             # Process historical blocks again
             for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
+                # Load K from CPU (on slot_transfer_stream)
                 offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
+                # wait_slot_layer makes compute_stream wait for H2D transfer
                 offload_engine.wait_slot_layer(slot)
 
-                k_block = offload_engine.get_k_for_slot(slot)
-                K_chunk = k_block.transpose(1, 2)
+                # All compute kernels run on compute_stream
+                with torch.cuda.stream(compute_stream):
+                    k_block = offload_engine.get_k_for_slot(slot)
+                    K_chunk = k_block.transpose(1, 2)
 
-                num_kv_heads = K_chunk.shape[1]
-                if num_heads != num_kv_heads:
-                    num_groups = num_heads // num_kv_heads
-                    K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
+                    num_kv_heads = K_chunk.shape[1]
+                    if num_heads != num_kv_heads:
+                        num_groups = num_heads // num_kv_heads
+                        K_chunk = K_chunk.repeat_interleave(num_groups, dim=1)
 
-                kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
+                    kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped
 
-                # Recompute attention scores (trade-off: compute vs memory)
-                attn_weights_kv = flat_group_gemm_fuse_reshape(
-                    Q, K_chunk, self.stride,
-                    chunk_start=chunk_start,
-                    chunk_end=chunk_end,
-                    is_causal=False,
-                )
+                    # Recompute attention scores (trade-off: compute vs memory)
+                    attn_weights_kv = flat_group_gemm_fuse_reshape(
+                        Q, K_chunk, self.stride,
+                        chunk_start=chunk_start,
+                        chunk_end=chunk_end,
+                        is_causal=False,
+                    )
 
-                # Normalize with global stats and compute block sums
-                block_sum_kv = softmax_normalize_and_block_sum(
-                    attn_weights_kv,
-                    m_global,
-                    l_global,
-                    reshaped_block_size,
-                    segment_size,
-                    chunk_start=chunk_start,
-                    real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
-                    scale=scale,
-                    kv_offset=kv_offset_reshaped,
-                    is_causal=True,
-                )
-                attn_sum_per_kv.append(block_sum_kv)
+                    # Normalize with global stats and compute block sums
+                    block_sum_kv = softmax_normalize_and_block_sum(
+                        attn_weights_kv,
+                        m_global,
+                        l_global,
+                        reshaped_block_size,
+                        segment_size,
+                        chunk_start=chunk_start,
+                        real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
+                        scale=scale,
+                        kv_offset=kv_offset_reshaped,
+                        is_causal=True,
+                    )
+                    attn_sum_per_kv.append(block_sum_kv)
 
-                offload_engine.record_slot_compute_done(slot)
-                del attn_weights_kv
+                    offload_engine.record_slot_compute_done(slot)
+                    del attn_weights_kv
 
+            # Process current chunk on compute_stream
+            with torch.cuda.stream(compute_stream):
+                # Recompute attention scores for current chunk
+                attn_weights_curr = flat_group_gemm_fuse_reshape(
+                    Q, K_current, self.stride,
+                    chunk_start=chunk_start,
+                    chunk_end=chunk_end,
+                    is_causal=False,
+                )
 
-                block_sum_kv = softmax_normalize_and_block_sum(
-                    attn_weights_kv,
-                    m_global,
-                    l_global,
-                    reshaped_block_size,
+                # Normalize with global stats and compute block sums
+                block_sum_curr = softmax_normalize_and_block_sum(
+                    attn_weights_curr,
+                    m_global,
+                    l_global,
+                    reshaped_block_size,
@@ -767,67 +783,42 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
chunk_start=chunk_start,
|
||||
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
|
||||
scale=scale,
|
||||
kv_offset=kv_offset_reshaped,
|
||||
kv_offset=kv_offset_current,
|
||||
is_causal=True,
|
||||
)
|
||||
attn_sum_per_kv.append(block_sum_kv)
|
||||
|
||||
offload_engine.record_slot_compute_done(slot)
|
||||
del attn_weights_kv
|
||||
|
||||
# Process current chunk
|
||||
# Recompute attention scores for current chunk
|
||||
attn_weights_curr = flat_group_gemm_fuse_reshape(
|
||||
Q, K_current, self.stride,
|
||||
chunk_start=chunk_start,
|
||||
chunk_end=chunk_end,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
block_sum_curr = softmax_normalize_and_block_sum(
|
||||
attn_weights_curr,
|
||||
m_global,
|
||||
l_global,
|
||||
reshaped_block_size,
|
||||
segment_size,
|
||||
chunk_start=chunk_start,
|
||||
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
|
||||
scale=scale,
|
||||
kv_offset=kv_offset_current,
|
||||
is_causal=True,
|
||||
)
|
||||
attn_sum_per_kv.append(block_sum_curr)
|
||||
del attn_weights_curr, K_current
|
||||
attn_sum_per_kv.append(block_sum_curr)
|
||||
del attn_weights_curr, K_current
|
||||
|
||||
# ================================================================
|
||||
# Step 4: Concatenate block sums and select blocks
|
||||
# Step 4: Concatenate block sums and select blocks (on compute_stream)
|
||||
```diff
         # ================================================================
-        attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
-        del attn_sum_per_kv, m_global, l_global
+        with torch.cuda.stream(compute_stream):
+            attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
+            del attn_sum_per_kv, m_global, l_global
 
-        # Calculate q_block offset for find_blocks_chunked
-        # This is the number of BSA blocks before Q in the full sequence
-        num_blocks_per_chunk = q_reshaped_len // reshaped_block_size
-        current_index = k_block_num - q_block_num  # Q starts at this BSA block index
+            # Calculate q_block offset for find_blocks_chunked
+            # This is the number of BSA blocks before Q in the full sequence
+            num_blocks_per_chunk = q_reshaped_len // reshaped_block_size
+            current_index = k_block_num - q_block_num  # Q starts at this BSA block index
 
-        with nvtx.range("xattn_find_blocks"):
-            mask = find_blocks_chunked(
-                attn_sum_concat,
-                current_index=current_index,
-                threshold=self.threshold,
-                num_to_choose=None,
-                decoding=False,
-                mode="prefill",
-                causal=True,
-            )
+            with nvtx.range("xattn_find_blocks"):
+                mask = find_blocks_chunked(
+                    attn_sum_concat,
+                    current_index=current_index,
+                    threshold=self.threshold,
+                    num_to_choose=None,
+                    decoding=False,
+                    mode="prefill",
+                    causal=True,
+                )
 
-        # Apply causal mask post-processing (same as xattn.py lines 1300-1306)
-        mask[:, :, -q_block_num:, -q_block_num:] = torch.where(
-            torch.tril(torch.ones(q_block_num, q_block_num, dtype=torch.bool, device=mask.device), diagonal=0),
-            mask[:, :, -q_block_num:, -q_block_num:],
-            False,
-        )
+            # Apply causal mask post-processing (same as xattn.py lines 1300-1306)
+            mask[:, :, -q_block_num:, -q_block_num:] = torch.where(
+                torch.tril(torch.ones(q_block_num, q_block_num, dtype=torch.bool, device=mask.device), diagonal=0),
+                mask[:, :, -q_block_num:, -q_block_num:],
+                False,
+            )
 
         # ================================================================
         # Step 5: Record density (only on layer 0)
         # ================================================================
```
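The `m_global` / `l_global` deleted above are the running max and normalizer of an online-softmax pass over KV chunks, whose per-chunk attention sums are what `torch.cat` stitches together. A minimal stand-alone sketch of that merge, in plain Python with illustrative names (not the file's actual helpers):

```python
import math

def merge_softmax_chunks(chunks):
    """Merge per-chunk (max, normalizer) statistics into a global pair,
    mirroring the role of m_global / l_global in the two-pass estimate."""
    m, l = float("-inf"), 0.0
    for scores in chunks:
        m_chunk = max(scores)
        l_chunk = sum(math.exp(s - m_chunk) for s in scores)
        m_new = max(m, m_chunk)
        # Rescale both the running and the chunk normalizer to the new max
        l = l * math.exp(m - m_new) + l_chunk * math.exp(m_chunk - m_new)
        m = m_new
    return m, l

chunks = [[0.1, 1.2], [2.0, -0.5], [0.7]]
flat = [s for chunk in chunks for s in chunk]
m, l = merge_softmax_chunks(chunks)
# Chunked merge matches the single-pass normalizer: sum(exp(s)) == l * exp(m)
assert abs(l * math.exp(m) - sum(math.exp(s) for s in flat)) < 1e-9
```

Because every per-chunk statistic is rescaled to a shared max, the chunks can be processed in any order and on any stream, which is exactly why the Pass 2 concatenation must see the same K data as Pass 1.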
```diff
@@ -908,20 +899,21 @@ class XAttentionBSAPolicy(SparsePolicy):
         if available_blocks and available_blocks[-1] not in selected_block_ids:
             selected_block_ids.append(available_blocks[-1])
 
-        # Record communication density
-        if available_blocks:
-            DensityObserver.record_comm_density(
-                layer_id,
-                selected_cpu_blocks=len(selected_block_ids),
-                total_cpu_blocks=len(available_blocks),
-            )
-
         # Update statistics (only for layer 0 to avoid overcounting)
         if layer_id == 0 and available_blocks:
             self._stats_total_available_blocks += len(available_blocks)
             self._stats_total_selected_blocks += len(selected_block_ids)
             self._stats_num_chunks += 1
+
+            # Record communication density to DensityObserver
+            # Comm density = selected_cpu_blocks / available_cpu_blocks
+            # This is different from compute density (BSA block granularity)
+            DensityObserver.record_comm_density(
+                layer_id=layer_id,
+                selected_cpu_blocks=len(selected_block_ids),
+                total_cpu_blocks=len(available_blocks),
+            )
+
+            # Log per-chunk density
+            chunk_density = len(selected_block_ids) / len(available_blocks)
+            logger.debug(f"[XAttn] chunk={ctx.query_chunk_idx}, available={len(available_blocks)}, "
```
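The recording logic in this hunk boils down to one ratio per (layer, chunk). A toy stand-alone version of the bookkeeping (a hypothetical class, not the real `DensityObserver` API) makes the guard and the averaging explicit:

```python
class CommDensityRecorder:
    """Toy comm-density bookkeeping: one ratio per (layer, chunk)."""

    def __init__(self):
        self.per_layer = {}  # layer_id -> list of per-chunk densities

    def record(self, layer_id, selected_cpu_blocks, total_cpu_blocks):
        if total_cpu_blocks == 0:
            return  # nothing available; skip, as the guarded call in the diff does
        density = selected_cpu_blocks / total_cpu_blocks
        self.per_layer.setdefault(layer_id, []).append(density)

    def per_layer_avg(self):
        return {lid: sum(d) / len(d) for lid, d in self.per_layer.items()}

rec = CommDensityRecorder()
rec.record(0, selected_cpu_blocks=3, total_cpu_blocks=8)
rec.record(0, selected_cpu_blocks=5, total_cpu_blocks=8)
print(rec.per_layer_avg())  # {0: 0.5}
```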
```diff
@@ -266,14 +266,31 @@ class DensityObserver(Observer):
             return 0.0
         return sum(all_densities) / len(all_densities)
 
+    @classmethod
+    def get_per_layer_comm_density(cls) -> Dict[int, float]:
+        """
+        Get the per-layer communication density (CPU block granularity).
+
+        Returns:
+            Dict[layer_id, avg_comm_density]
+        """
+        result = {}
+        for layer_id, densities in cls._layer_comm_densities.items():
+            if densities:
+                result[layer_id] = sum(densities) / len(densities)
+        return result
+
     @classmethod
     def get_summary(cls) -> dict:
         """Return a summary of the statistics."""
         per_layer = cls.get_per_layer_density()
+        per_layer_comm = cls.get_per_layer_comm_density()
         return {
             "mode": cls._mode,
-            "overall_density": cls.get_overall_density(),
-            "per_layer_density": per_layer,
+            "overall_compute_density": cls.get_overall_density(),
+            "overall_comm_density": cls.get_overall_comm_density(),
+            "per_layer_compute_density": per_layer,
+            "per_layer_comm_density": per_layer_comm,
             "num_layers": len(per_layer),
             "last_mask_shape": {
                 "q_blocks": cls._last_q_blocks,
```
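A useful companion to `get_per_layer_comm_density` is the reason comm density (4096-token CPU blocks) systematically exceeds compute density (128-token BSA blocks): a CPU block must be transferred if any one of its 32 BSA blocks is selected, so evenly spread sparse selections aggregate into dense transfers. A self-contained illustration (block sizes taken from the docs above; the function name is illustrative):

```python
# One CPU block (4096 tokens) covers 32 BSA blocks (128 tokens each).
BSA_PER_CPU = 4096 // 128

def comm_density_from_bsa_mask(selected_bsa, num_bsa_blocks):
    """A CPU block is needed if ANY of its BSA blocks is selected,
    so sparse BSA selections aggregate into dense CPU-block transfers."""
    num_cpu_blocks = (num_bsa_blocks + BSA_PER_CPU - 1) // BSA_PER_CPU
    needed = {b // BSA_PER_CPU for b in selected_bsa}
    return len(needed) / num_cpu_blocks

# 10% compute density (32 of 320 BSA blocks), but evenly spread:
# every one of the 10 CPU blocks is touched, so comm density is 100%.
selected = list(range(0, 320, 10))
print(comm_density_from_bsa_mask(selected, 320))  # 1.0
```

This aggregation effect is why the docs table above reports comm=100% even when compute density is far lower.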
```diff
@@ -301,7 +318,9 @@ class DensityObserver(Observer):
         print(f"[DensityObserver] Mode: {cls._mode}")
         print(f"  Compute density: {overall:.4f} (min: {min_density:.4f} @ layer {min_layer})")
         if overall_comm > 0:
-            print(f"  Comm density: {overall_comm:.4f}")
+            # Offload mode: show both densities with explanation
+            print(f"  Comm density: {overall_comm:.4f} (CPU block granularity)")
+            print(f"  Savings ratio: {1 - overall_comm:.1%} H2D transfer reduction")
         print(f"  Num layers: {len(per_layer)}")
         # Print layer 0 density for comparison
         if 0 in per_layer:
```
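The new `Savings ratio` line leans on Python's `%` format spec, which multiplies by 100 and appends `%`; for example:

```python
overall_comm = 0.34
line = f"  Savings ratio: {1 - overall_comm:.1%} H2D transfer reduction"
print(line)  # "  Savings ratio: 66.0% H2D transfer reduction"
```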
```diff
@@ -386,8 +386,11 @@ def run_ruler_benchmark(
     if sparse_policy and sparse_policy.upper() == "XATTN_BSA":
         DensityObserver.enable()
         DensityObserver.complete_reset()
+        # Set mode for correct density interpretation
+        DensityObserver.set_mode("offload" if enable_cpu_offload else "gpu_only")
         if not json_output:
-            print("[DensityObserver] Enabled for XAttention BSA")
+            mode_str = "offload" if enable_cpu_offload else "gpu_only"
+            print(f"[DensityObserver] Enabled for XAttention BSA (mode: {mode_str})")
 
     # LLM initialization kwargs
     llm_kwargs = {
```