[WIP] Before refactoring compute_chunked_prefill.

This commit is contained in:
Zijie Tian
2026-01-23 03:36:12 +08:00
parent edc006463b
commit ca32ea6f93
7 changed files with 914 additions and 114 deletions

View File

@@ -48,7 +48,7 @@ class Config:
# XAttention BSA specific parameters
sparse_block_size: int = 128 # Block size for BSA (tokens per block)
sparse_samples_per_chunk: int = 128 # Samples per chunk for estimation
sparse_threshold: float = 0.9 # Cumulative attention threshold (0-1)
sparse_threshold: float = 0.95 # Cumulative attention threshold (tau in XAttention)
sparse_use_triton: bool = True # Use Triton kernels for estimation
sparse_stride: int = 8 # Stride for Q/K downsampling

View File

@@ -124,42 +124,6 @@ class XAttentionBSAPolicy(SparsePolicy):
"""
return available_blocks
def _load_all_historical_kv(
self,
cpu_block_table: List[int],
layer_id: int,
offload_engine: "OffloadEngine",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Load all historical K/V from CPU to GPU.
Args:
cpu_block_table: List of CPU block IDs
layer_id: Current layer index
offload_engine: OffloadEngine instance
Returns:
(k_hist, v_hist) with shape [total_tokens, kv_heads, head_dim]
"""
if not cpu_block_table:
return None, None
k_list = []
v_list = []
for cpu_block_id in cpu_block_table:
k_block, v_block = offload_engine.load_block_full_from_cpu(
cpu_block_id, layer_id
)
k_list.append(k_block)
v_list.append(v_block)
# Concatenate: [num_blocks, block_size, kv_heads, head_dim] -> [total_tokens, kv_heads, head_dim]
k_hist = torch.cat(k_list, dim=0)
v_hist = torch.cat(v_list, dim=0)
return k_hist, v_hist
def compute_chunked_prefill(
self,
q: torch.Tensor,

View File

@@ -258,8 +258,12 @@ class Attention(nn.Module):
raise RuntimeError("sparse_policy is required for chunked decode")
# Check if policy supports decode phase
# If not, fallback to FullAttentionPolicy (e.g., XAttentionBSAPolicy only supports prefill)
if not sparse_policy.supports_decode:
raise RuntimeError(f"{sparse_policy} does not support decode phase")
from nanovllm.kvcache.sparse import FullAttentionPolicy
sparse_policy = FullAttentionPolicy()
logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, "
f"falling back to FullAttentionPolicy")
# [DEBUG] Verify execution path
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "

View File

@@ -1,90 +1,286 @@
# Task Plan: XAttention BSA 集成到 nanovllm
# Task Plan: XAttention BSA 真正的 Sparse 实现
## Goal
使用 `--sparse-policy XATTN_BSA` 运行 `test_ruler.py`,通过 `niah_single_1` 的前 5 个 sample。
实现 XAttentionBSAPolicy 的真正 sparse attention`select_blocks` 中使用 `xattn_estimate_chunked` 选择重要的 blocks然后复用 FullAttentionPolicy 的 ring buffer pipeline。
**验收标准**:
```bash
CUDA_VISIBLE_DEVICES=X PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--enable-offload \
--sparse-policy XATTN_BSA \
--task niah_single_1 \
--sample-ids 0,1,2,3,4
# 期望: 5/5 PASS
--datasets niah_single_1 \
--sample-indices 0,1,2,3,4
# 期望: 5/5 PASS,并且真正使用 sparse selection
```
## 当前状态
## 当前状态: Phase 1 - 代码分析完成
- `XAttentionBSAPolicy.compute_chunked_prefill` 实现 = `FullAttentionPolicy`(无 sparse)
- `xattn_estimate_chunked` 已实现并验证
- BSA kernel (`block_sparse_attn`) 可用
## 核心设计理解
### 1. Block Size 关系
| 参数 | 值 | 说明 |
|------|-----|------|
| BSA block_size | 128 tokens | XAttention 的 block 粒度 |
| kvcache_block_size | 1024 tokens | CPU offload 的 block 粒度 |
| 比例 | 1:8 | 1 CPU block = 8 BSA blocks |
### 2. 特化条件(用户要求)
- BSA chunk_size = 外部 chunk_size
- 这样 `xattn_estimate_chunked` 返回的 mask 可以直接映射到 CPU block selection
- 复用现有的 `flash_attn_with_lse` + `merge_attention_outputs`
### 3. select_blocks 设计
```
select_blocks(available_blocks, offload_engine, ctx) -> List[int]
├─ 1. 从 metadata cache 获取下采样的 K
│ (在 on_prefill_offload 中收集)
├─ 2. 调用 xattn_estimate_chunked(Q, K_downsampled, q_start_pos)
│ 返回 mask: [B, H, q_blocks, k_blocks]
├─ 3. 将 BSA k_blocks 映射到 CPU block IDs
│ 每 8 个 BSA blocks = 1 CPU block
│ 只要 8 个中有任意一个被选中,就保留该 CPU block
└─ 4. 返回 selected_cpu_blocks
```
### 4. Metadata 存储策略
**方案 A**: 存储下采样的 K内存友好
```python
# on_prefill_offload 中:
k_downsampled = k_cache[::stride] # [block_size/stride, H, D]
self._k_cache[layer_id][cpu_block_id] = k_downsampled
```
**内存计算** (stride=8):
- 每 block: (1024/8) * 8 * 128 * 2 bytes = 256 KB
- 256 blocks * 32 layers = 2 GB (GPU 上用于快速估计)
**方案 B**: 存储 min/max metadata (更省内存)
```python
# on_prefill_offload 中:
k_min = k_cache[:num_valid].min(dim=0).values # [H, D]
k_max = k_cache[:num_valid].max(dim=0).values # [H, D]
```
- 但这需要不同的估计算法,不能直接用 xattn_estimate
**决定**: 使用方案 A下采样 K因为可以直接复用 xattn_estimate_chunked
## Phases
- [ ] Phase 1: 理解当前代码路径
- [ ] Phase 2: 实现 sparse mask 估计
- [ ] Phase 3: 实现 BSA sparse 计算
- [ ] Phase 4: 测试验证
- [x] Phase 1: 代码分析,理解当前实现
- [ ] Phase 2: 实现 on_prefill_offload 收集 K metadata
- [ ] Phase 3: 实现 select_blocks 中的 xattn estimation
- [ ] Phase 4: 实现 BSA block → CPU block 的映射
- [ ] Phase 5: 测试验证
## Phase 1: 理解当前代码路径
## Phase 2: on_prefill_offload 实现
### 1.1 确认 XATTN_BSA policy 是否被正确加载
- [ ] 检查 `test_ruler.py` 如何解析 `--sparse-policy XATTN_BSA`
- [ ] 检查 `KVCacheManager` 如何实例化 sparse_policy
- [ ] 运行 baseline 测试(`--sparse-policy FULL`)确认基础功能正常
### 需要修改的文件
- `nanovllm/kvcache/sparse/xattn_bsa.py`
### 1.2 确认数据流
- [ ] `compute_chunked_prefill` 的输入参数含义
- [ ] `offload_engine` 提供的数据访问接口
- [ ] 当前 chunk 的 K/V 如何获取
### 实现细节
## Phase 2: 实现 sparse mask 估计
```python
class XAttentionBSAPolicy(SparsePolicy):
def __init__(self, threshold=0.9, stride=8, ...):
self.threshold = threshold
self.stride = stride
self._k_cache: Dict[int, Dict[int, torch.Tensor]] = {}
# _k_cache[layer_id][cpu_block_id] = k_downsampled
### 2.1 调用 xattn_estimate_chunked
- [ ]`compute_chunked_prefill` 中加载历史 K
- [ ] 拼接历史 K + 当前 K
- [ ] 调用 `xattn_estimate_chunked(q, k_full, q_start_pos=...)`
- [ ] 获取 block mask
def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device):
"""初始化 K cache 结构"""
self._k_cache = {layer_id: {} for layer_id in range(num_layers)}
self._num_kv_heads = num_kv_heads
self._head_dim = head_dim
### 2.2 处理参数对齐
- [ ] BSA block_size = 128
- [ ] chunk_size 与 kvcache_block_size 的关系
- [ ] q_start_pos 计算
def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
"""收集下采样的 K 用于后续估计"""
# k_cache: [block_size, num_kv_heads, head_dim]
k_downsampled = k_cache[:num_valid_tokens:self.stride].clone()
# k_downsampled: [num_valid_tokens//stride, num_kv_heads, head_dim]
self._k_cache[layer_id][cpu_block_id] = k_downsampled
```
## Phase 3: 实现 BSA sparse 计算
## Phase 3: select_blocks 实现
### 3.1 方案选择
- 选项 A: 历史 + 当前分开计算,然后 merge
- 选项 B: 全部一起用 BSA 计算
### 关键问题
### 3.2 实现
- [ ] 构造 BSA 需要的输入格式
- [ ] 调用 `block_sparse_attn_func`
- [ ] 处理输出格式
1. **Q 从哪里来?**
- `ctx.query` 需要在调用 select_blocks 时传入
- 当前 FullAttentionPolicy 传递 `query=None`
- 需要修改 compute_chunked_prefill 传递真实的 Q
## Phase 4: 测试验证
2. **Q 的格式转换**
- 输入 Q: [seq_len, num_heads, head_dim]
- xattn 需要: [B, H, q_len, D]
- 转换: `q.unsqueeze(0).transpose(1, 2)`
### 4.1 单元测试
- [ ] 验证 sparse mask 与 `test_xattn_estimate_chunked.py` 一致
3. **K 的组装**
-`_k_cache[layer_id]` 获取各 block 的下采样 K
-`available_blocks` 顺序 cat 起来
- 结果: [B, H, total_k_downsampled, D]
### 4.2 集成测试
- [ ] 运行验收命令
- [ ] 5/5 PASS
### 实现草案
## Key Questions
```python
def select_blocks(self, available_blocks, offload_engine, ctx):
if not available_blocks or ctx.query is None:
return available_blocks
1. 历史 K 如何高效加载?(全量 vs 按需)
2. **BSA causal mask 如何处理?(历史 non-causal + 当前 causal)
layer_id = ctx.layer_id
# 1. 组装下采样的 K
k_list = []
for cpu_block_id in available_blocks:
if cpu_block_id in self._k_cache[layer_id]:
k_list.append(self._k_cache[layer_id][cpu_block_id])
if not k_list:
return available_blocks
k_hist = torch.cat(k_list, dim=0) # [total_tokens/stride, H, D]
k_hist = k_hist.unsqueeze(0).transpose(1, 2) # [1, H, k_len, D]
# 2. 准备 Q
q = ctx.query # [seq_len, num_heads, head_dim]
q = q.unsqueeze(0).transpose(1, 2) # [1, H, q_len, D]
# GQA 扩展(如果需要)
if q.shape[1] != k_hist.shape[1]:
num_groups = q.shape[1] // k_hist.shape[1]
k_hist = k_hist.repeat_interleave(num_groups, dim=1)
# 3. 计算 q_start_pos
q_start_pos = len(available_blocks) * ctx.block_size
# 4. 调用 xattn_estimate_chunked
# 注意K 已经是下采样的,需要调整参数
attn_sum, mask = xattn_estimate_chunked(
q, k_hist,
q_start_pos=q_start_pos // self.stride, # 调整到下采样空间
block_size=self.BSA_BLOCK_SIZE // self.stride, # 16
stride=1, # K 已经下采样
threshold=self.threshold,
chunk_size=q.shape[2], # 与 Q 长度一致
use_triton=self.use_triton,
)
# 5. 从 mask 提取 CPU block IDs
# mask: [1, H, q_blocks, k_blocks]
# 对所有 heads 取 OR
selected_mask = mask.any(dim=1).squeeze(0) # [q_blocks, k_blocks]
# 对所有 q_blocks 取 OR只要任意 Q 位置需要这个 K block
selected_k_mask = selected_mask.any(dim=0) # [k_blocks]
# 6. 映射 BSA blocks → CPU blocks
# 每个 CPU block = 8 BSA blocks (block_size=1024, BSA_block=128)
bsa_to_cpu_ratio = ctx.block_size // self.BSA_BLOCK_SIZE # 8
num_cpu_blocks = len(available_blocks)
selected_cpu_indices = set()
for bsa_idx in selected_k_mask.nonzero(as_tuple=True)[0].tolist():
cpu_idx = bsa_idx // bsa_to_cpu_ratio
if cpu_idx < num_cpu_blocks:
selected_cpu_indices.add(cpu_idx)
selected_blocks = [available_blocks[i] for i in sorted(selected_cpu_indices)]
logger.info(f"[XAttn] select_blocks: {len(available_blocks)} -> {len(selected_blocks)} "
f"({100*len(selected_blocks)/len(available_blocks):.1f}%)")
return selected_blocks
```
## Phase 4: compute_chunked_prefill
### 关键修改
1. **传递真实的 Q 给 select_blocks**
- 修改 PolicyContext 构造,设置 `query=q`
2. **复用 FullAttentionPolicy 的 pipeline**
- 继承 FullAttentionPolicy 而不是 SparsePolicy
- 或者直接调用父类方法
### 方案对比
**方案 A**: XAttentionBSAPolicy 继承 FullAttentionPolicy
```python
class XAttentionBSAPolicy(FullAttentionPolicy):
# 只需要 override select_blocks 和 on_prefill_offload
# compute_chunked_prefill 直接用父类的
```
**方案 B**: 独立实现,调用相同的 pipeline 代码
```python
class XAttentionBSAPolicy(SparsePolicy):
def compute_chunked_prefill(self, q, k, v, ...):
# 复制 FullAttentionPolicy 的代码
# 但修改 PolicyContext 传递 query=q
```
**决定**: 使用方案 B因为需要在 compute_chunked_prefill 中修改 PolicyContext
## Phase 5: 测试
### 单元测试
```bash
# 测试 select_blocks 的 sparsity
python -c "
from nanovllm.kvcache.sparse.xattn_bsa import XAttentionBSAPolicy
policy = XAttentionBSAPolicy(threshold=0.9)
# ... 测试代码
"
```
### 集成测试
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_ruler.py \
--model ~/models/Llama-3.1-8B-Instruct \
--enable-offload \
--sparse-policy XATTN_BSA \
--datasets niah_single_1 \
--sample-indices 0,1,2,3,4
```
## Key Decisions
| 决策 | 理由 |
|------|------|
| 使用下采样 K 作为 metadata | 可以直接复用 xattn_estimate_chunked |
| stride=8 | 平衡内存和精度 |
| BSA blocks → CPU blocks 映射用 OR | 只要有一个 BSA block 被选中就保留 |
| 继承 FullAttentionPolicy 的 pipeline | 复用已验证的 ring buffer 流程 |
## Files to Modify
| 文件 | 修改 |
|------|------|
| `nanovllm/kvcache/sparse/xattn_bsa.py` | 主要实现initialize, on_prefill_offload, select_blocks |
## 注意事项
1. **GQA 处理**: Llama-3.1-8B 有 32 query heads, 8 kv heads需要在估计时扩展 K
2. **内存管理**: `_k_cache` 存储在 GPU需要在 reset() 时清理
3. **Triton 兼容性**: xattn_estimate_chunked 有 Triton bug可能需要用 PyTorch fallback
4. **边界条件**: 第一个 chunk (available_blocks=[]) 时直接返回空列表
## Errors Encountered
(待填充)
## Status
**Currently in Phase 1** - 等待用户确认后开始
## 待讨论
请确认:
1. 这个 goal 和验收标准是否正确?
2. 我使用哪个 GPU 运行测试?
**Currently in Phase 1** - 代码分析完成,准备开始 Phase 2 实现

334
tests/test_xattn_bsa.py Normal file
View File

@@ -0,0 +1,334 @@
"""
Test XAttention + BSA with RULER benchmark data.
Tests XAttention sparse attention correctness using RULER NIAH task.
Attention methods:
- Prefill: XAttention + BSA (sparse) or FlashAttention (dense)
- Decode: FlashAttention (always, since q_len=1)
Usage (in compass conda env with BSA available):
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct
# Test with XAttention + BSA for prefill (default)
python tests/test_xattn_bsa.py --prefill-method xattn
# Test with FlashAttention for prefill (baseline)
python tests/test_xattn_bsa.py --prefill-method flash
# Test specific sample(s)
python tests/test_xattn_bsa.py --sample-id 0
python tests/test_xattn_bsa.py --sample-ids 0,1,2
Note: Compatible with transformers 4.53+ (handles both old `past_key_value`
and new `past_key_values` API).
"""
import argparse
import json
import sys
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
from nanovllm.ops.xattn import xattn_estimate
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
# ============================================================
# XAttention + BSA Functions
# ============================================================
def expand_kv_for_gqa(key_states, value_states, num_heads):
    """Replicate K/V heads so their count matches the query head count (GQA).

    If the model already has as many KV heads as query heads (plain MHA),
    the inputs are returned untouched; otherwise each KV head is repeated
    ``num_heads // num_kv_heads`` times along the head dimension (dim=1).
    """
    kv_head_count = key_states.shape[1]
    if kv_head_count == num_heads:
        return key_states, value_states
    repeats = num_heads // kv_head_count
    expanded_k = key_states.repeat_interleave(repeats, dim=1)
    expanded_v = value_states.repeat_interleave(repeats, dim=1)
    return expanded_k, expanded_v
def flash_attention_forward(query_states, key_states, value_states, is_causal=True):
    """Standard FlashAttention.

    Inputs are [batch, heads, seq, head_dim]; ``flash_attn_func`` expects
    [batch, seq, heads, head_dim], so the tensors are transposed in and the
    result transposed back before returning.
    """
    from flash_attn import flash_attn_func
    q, k, v = (t.transpose(1, 2) for t in (query_states, key_states, value_states))
    out = flash_attn_func(q, k, v, causal=is_causal)
    return out.transpose(1, 2)
def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9):
    """XAttention + BSA sparse attention.

    1. Estimate a block-level attention mask with ``xattn_estimate``.
    2. Run block-sparse attention restricted to the selected blocks.

    Args:
        query_states/key_states/value_states: [batch, heads, seq, head_dim]
            (K/V already expanded for GQA, so all share ``num_heads``).
        threshold: cumulative-attention threshold for block selection.

    Returns:
        Attention output of shape [batch, heads, q_len, head_dim].

    FIX: removed a leftover ``__import__('pdb').set_trace()`` breakpoint
    that halted execution on every call.
    """
    from block_sparse_attn import block_sparse_attn_func
    batch_size, num_heads, q_len, head_dim = query_states.shape
    k_len = key_states.shape[2]
    # Block-level importance estimation; mask: [B, H, q_blocks, k_blocks].
    _, mask = xattn_estimate(
        query_states, key_states,
        chunk_size=16384, block_size=128, threshold=threshold,
        use_triton=True, causal=True,
    )
    q_block_num = (q_len + 127) // 128
    k_block_num = (k_len + 127) // 128
    # BSA expects a flattened varlen layout: [total_tokens, heads, head_dim].
    q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
    k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
    v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
    output = block_sparse_attn_func(
        q, k, v,
        torch.tensor([0, q_len], dtype=torch.int32, device=q.device),
        torch.tensor([0, k_len], dtype=torch.int32, device=k.device),
        torch.ones(num_heads, dtype=torch.int32, device=q.device),
        None,
        mask[:, :, :q_block_num, :k_block_num].contiguous(),
        q_len, k_len,
        p_dropout=0.0, deterministic=True, is_causal=True,
    )
    return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
DEBUG = False  # Set to True to enable debugging
def create_patched_forward(prefill_method="xattn", threshold=0.9):
    """Create patched forward with configurable prefill method.

    Args:
        prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense)
        threshold: XAttention threshold for block selection (only used when prefill_method="xattn")

    Note:
        - Prefill (q_len > 1): Uses specified prefill_method
        - Decode (q_len = 1): Always uses FlashAttention (no sparse needed for single query)
    """
    call_count = [0]  # Mutable to track calls across layers
    def patched_forward(
        self,
        hidden_states,
        position_embeddings=None,
        attention_mask=None,
        past_key_value=None,  # Old API (transformers < 4.57)
        past_key_values=None,  # New API (transformers >= 4.57)
        cache_position=None,
        **kwargs
    ):
        # Handle both old and new transformers API
        kv_cache = past_key_values if past_key_values is not None else past_key_value
        bsz, q_len, _ = hidden_states.size()
        num_heads = self.config.num_attention_heads
        num_kv_heads = self.config.num_key_value_heads
        head_dim = self.head_dim
        # Compute Q, K, V projections -> [bsz, heads, q_len, head_dim]
        query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
        # Apply rotary position embedding
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        # Handle KV cache: update() appends the new K/V and returns the full history
        if kv_cache is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = kv_cache.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
        # Expand KV for GQA so K/V head count matches the query head count
        key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads)
        # Debug output (layer 0 only, first few calls)
        if DEBUG and self.layer_idx == 0:
            call_count[0] += 1
            if call_count[0] <= 5:
                phase = "prefill" if q_len > 1 else "decode"
                print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}")
                print(f" kv_cache is None: {kv_cache is None}")
        # Choose attention method:
        # - Prefill (q_len > 1): Use prefill_method (xattn or flash)
        # - Decode (q_len = 1): Always use FlashAttention
        is_prefill = q_len > 1
        if is_prefill and prefill_method == "xattn":
            # Prefill with XAttention + BSA (sparse)
            attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold)
        else:
            # Prefill with FlashAttention (dense) OR Decode (always FlashAttention)
            # Note: For decode (q_len=1), causal=False since single query attends to all KV
            attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill)
        attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1))
        return attn_output, None
    return patched_forward
# ============================================================
# Data & Evaluation
# ============================================================
def load_samples(filepath, indices=None):
    """Read a JSONL file and return its records as dicts.

    Each returned sample is annotated with its original line number under
    the ``"_idx"`` key. When ``indices`` is given, only lines whose index
    is contained in it are kept.
    """
    selected = []
    with open(filepath) as handle:
        for line_no, raw in enumerate(handle):
            if indices is not None and line_no not in indices:
                continue
            record = json.loads(raw)
            record["_idx"] = line_no
            selected.append(record)
    return selected
def string_match_all(output_text, expected_list):
    """RULER metric: fraction of expected values found in output."""
    # Normalize: case-insensitive match over a single-line view of the output.
    haystack = output_text.lower().replace('\n', ' ')
    if not expected_list:
        return 1.0
    hits = 0.0
    for expected in expected_list:
        if expected.strip().lower() in haystack:
            hits += 1.0
    return hits / len(expected_list)
# ============================================================
# Test
# ============================================================
def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50):
    """Test attention methods using RULER data.

    Args:
        model_path: HF model directory to load.
        data_file: path to a RULER validation JSONL file.
        sample_ids: sample indices to run (falsy -> load all samples).
        prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention
        threshold: XAttention block-selection threshold (xattn only).
        max_new_tokens: generation budget per sample.

    Returns:
        True when the average score across samples is >= 0.5.
    """
    prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)"
    print("=" * 60)
    print("RULER NIAH Attention Test")
    print("=" * 60)
    print(f"Data: {data_file}")
    print(f"Samples: {sample_ids}")
    print(f"Prefill method: {prefill_desc}")
    print(f"Decode method: FlashAttention (always)")
    if prefill_method == "xattn":
        print(f"XAttention threshold: {threshold}")
    samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None)
    if not samples:
        print("No samples found!")
        return False
    print(f"Loaded {len(samples)} samples")
    # Load model with eager attention so our monkey-patch takes effect.
    print(f"\nLoading model: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map="cuda",
        attn_implementation="eager",  # Will be patched
    )
    model.eval()
    # Patch all layers with the configurable forward.
    print(f"Patching attention layers...")
    print(f" - Prefill: {prefill_desc}")
    print(f" - Decode: FlashAttention")
    for idx, layer in enumerate(model.model.layers):
        layer.self_attn.layer_idx = idx  # Ensure layer_idx is set
        # __get__ binds the plain function as a method of this module instance.
        layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__(
            layer.self_attn, type(layer.self_attn)
        )
    total_score = 0.0
    results = []
    for sample in samples:
        idx = sample["_idx"]
        prompt = sample["input"]
        expected = sample["outputs"]
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        num_tokens = inputs["input_ids"].shape[1]
        print(f"\n--- Sample {idx} ({num_tokens} tokens) ---")
        print(f"Expected: {expected}")
        with torch.no_grad():
            output = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Decode only the newly generated tokens (skip the prompt).
        output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True)
        score = string_match_all(output_text, expected)
        total_score += score
        status = "✓ PASS" if score >= 0.5 else "✗ FAIL"
        print(f"Output: '{output_text[:100]}...'")
        print(f"Result: {status} (score={score:.2f})")
        results.append({"idx": idx, "score": score, "passed": score >= 0.5})
    avg_score = total_score / len(samples)
    passed = sum(1 for r in results if r["passed"])
    print(f"\n{'='*60}")
    print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}")
    print(f"{'='*60}")
    return avg_score >= 0.5
def main():
    """CLI entry: parse arguments, verify BSA availability, run the RULER test.

    FIX: the model path previously used ``args.model.replace("~", "/home/zijie")``,
    hard-coding one user's home directory; now uses ``Path.expanduser()``.
    """
    parser = argparse.ArgumentParser(
        description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark"
    )
    parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct")
    parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl")
    parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index")
    parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)")
    parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn",
                        help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)")
    parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)")
    parser.add_argument("--max-new-tokens", type=int, default=50)
    # Keep old option for backwards compatibility
    parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead")
    args = parser.parse_args()
    # Expand "~" portably for any user instead of a hard-coded home dir.
    model_path = str(Path(args.model).expanduser())
    # Handle deprecated --no-xattn option
    prefill_method = args.prefill_method
    if args.no_xattn:
        prefill_method = "flash"
        print("Warning: --no-xattn is deprecated, use --prefill-method flash instead")
    if args.sample_id is not None:
        sample_ids = [args.sample_id]
    elif args.sample_ids:
        sample_ids = [int(x) for x in args.sample_ids.split(",")]
    else:
        sample_ids = [0]
    # Check BSA availability if using xattn
    if prefill_method == "xattn":
        try:
            from block_sparse_attn import block_sparse_attn_func
            print("✓ BSA (Block Sparse Attention) available")
        except ImportError:
            print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash")
            sys.exit(1)
    if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens):
        print("\ntest_xattn_bsa: PASSED")
    else:
        print("\ntest_xattn_bsa: FAILED")
        sys.exit(1)


if __name__ == "__main__":
    main()

259
tests/test_xattn_chunked.py Normal file
View File

@@ -0,0 +1,259 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked
Verify that chunked estimation with EXTERNAL chunking produces the same mask as standard estimation.
Uses real QKV data captured from model inference.
"""
import sys
import os
import torch
import warnings
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
# ============================================================
# Configuration
# ============================================================
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096
# Default QKV data directory (relative to project root)
DEFAULT_QKV_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "kvcache")
# ============================================================
# Utility Functions
# ============================================================
def load_qkv(path):
    """Deserialize a saved QKV capture, echo its basic stats, and return it.

    NOTE(review): ``weights_only=False`` runs the full pickle machinery —
    only load trusted capture files.
    """
    data = torch.load(path, map_location="cpu", weights_only=False)
    print(f"Loaded: {path}")
    print(f" Query shape: {data['query'].shape}")
    print(f" Key shape: {data['key'].shape}")
    print(f" Layer: {data['layer_id']}, Density: {data['density']:.2%}")
    return data
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
    """Compare two masks and report differences."""
    # A shape mismatch is an immediate failure — nothing comparable.
    if mask1.shape != mask2.shape:
        print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
        return False
    mismatch = mask1 != mask2
    diff = mismatch.sum().item()
    total = mask1.numel()
    match_rate = (total - diff) / total * 100
    print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
    if diff > 0:
        diff_indices = torch.where(mismatch)
        positions = list(zip(*[idx[:5].tolist() for idx in diff_indices]))
        print(f" First 5 diff positions: {positions}")
    return diff == 0
def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size):
    """
    Run xattn_estimate_chunked with EXTERNAL chunking.

    This simulates how chunked prefill should be used in practice: Q is
    split into chunks of ``chunk_size`` tokens, and each chunk only sees
    the K prefix up to its own last token (causal accumulation).

    Args:
        query: [batch, heads, q_len, head_dim]
        key: [batch, heads, k_len, head_dim]
        q_start_pos: absolute position of query's first token in the sequence
        block_size: XAttention block size (tokens per block)
        stride: Q/K downsampling stride used by the estimator
        threshold: cumulative attention threshold
        chunk_size: number of Q tokens per external chunk

    Returns:
        (attn_sum, mask), both shaped [batch, heads, q_blocks, k_blocks].
    """
    batch_size, num_heads, q_len, head_dim = query.shape
    _, _, k_len, _ = key.shape
    # ceil-division to count blocks covering the full Q/K ranges
    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size
    # If Q fits in one chunk, call directly (no external splitting needed)
    if q_len <= chunk_size:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            return xattn_estimate_chunked(
                query, key,
                q_start_pos=q_start_pos,
                block_size=block_size,
                stride=stride,
                threshold=threshold,
                use_triton=True,
                chunk_size=chunk_size,
            )
    # External chunking: split Q and call for each chunk
    num_q_chunks = (q_len + chunk_size - 1) // chunk_size
    print(f" External chunking: {num_q_chunks} chunks")
    # Pre-size the combined outputs; chunks fill disjoint q-block rows.
    combined_attn_sum = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=query.dtype, device=query.device
    )
    combined_mask = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=torch.bool, device=query.device
    )
    q_block_offset = 0
    for q_chunk_idx in range(num_q_chunks):
        q_chunk_start = q_chunk_idx * chunk_size
        q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
        q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
        # For causal attention, K accumulates up to current Q position
        k_end = q_start_pos + q_chunk_end
        k_chunk = key[:, :, :k_end, :]
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
                q_chunk, k_chunk,
                q_start_pos=q_start_pos + q_chunk_start,
                block_size=block_size,
                stride=stride,
                threshold=threshold,
                use_triton=True,
                chunk_size=chunk_size,
            )
        # Place chunk results into combined output
        chunk_q_blocks = mask_chunk.shape[2]
        chunk_k_blocks = mask_chunk.shape[3]
        combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
        combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
        q_block_offset += chunk_q_blocks
    return combined_attn_sum, combined_mask
def test_single_qkv(qkv_path):
    """Test a single QKV file.

    Runs the standard estimator and the externally-chunked estimator on the
    same capture and returns True when the resulting masks match exactly.
    """
    data = load_qkv(qkv_path)
    # bf16 on GPU mirrors the production inference setup.
    query = data["query"].cuda().to(torch.bfloat16)
    key = data["key"].cuda().to(torch.bfloat16)
    seq_len = query.shape[2]
    print(f"\nTesting with seq_len={seq_len}")
    print("=" * 60)
    # Run standard xattn_estimate
    print("[1] Running standard xattn_estimate...")
    try:
        attn_sum_std, mask_std = xattn_estimate(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            use_triton=True,
        )
        print(f" mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}")
    except Exception as e:
        print(f" ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False
    # Run chunked xattn_estimate with EXTERNAL chunking
    print("[2] Running chunked xattn_estimate (external chunking)...")
    try:
        attn_sum_chunked, mask_chunked = run_chunked_externally(
            query, key,
            q_start_pos=0,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
        )
        print(f" mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}")
    except Exception as e:
        print(f" ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False
    # Compare results
    print("[3] Comparing results...")
    chunked_q_blocks = mask_chunked.shape[2]
    chunked_k_blocks = mask_chunked.shape[3]
    # Extract comparable region from standard mask (the two estimators may
    # pad to different block counts at the tail).
    mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    # Compare masks
    masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
    # Compare attn_sums (informational only; does not affect pass/fail)
    attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
        attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
        print(f" Attn sum max diff: {attn_diff:.6f}")
    else:
        print(f" Attn sum shape mismatch")
    # Clean up GPU memory before the next (possibly larger) capture
    del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
    torch.cuda.empty_cache()
    return masks_match
# ============================================================
# Main Test
# ============================================================
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked")
    parser.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR,
                        help="Directory containing QKV files")
    args = parser.parse_args()
    # QKV files to test (captures at increasing sequence lengths)
    qkv_files = [
        os.path.join(args.qkv_dir, "qkv_3688.pt"),  # ~4K
        os.path.join(args.qkv_dir, "qkv_7888.pt"),  # ~8K
        os.path.join(args.qkv_dir, "qkv_15685.pt"),  # ~16K
        os.path.join(args.qkv_dir, "qkv_32485.pt"),  # ~32K
        os.path.join(args.qkv_dir, "qkv_64891.pt"),  # ~64K
    ]
    # Only run on captures that actually exist on disk
    available_files = [p for p in qkv_files if os.path.exists(p)]
    if not available_files:
        print(f"No QKV file found in {args.qkv_dir}.")
        print(f"Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt")
        sys.exit(1)
    print(f"Found {len(available_files)} QKV files to test")
    print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})")
    print(f"Using Triton kernels")
    all_passed = True
    results = []
    for qkv_path in available_files:
        passed = test_single_qkv(qkv_path)
        # Sequence length is encoded in the file name: qkv_<seq_len>.pt
        seq_len = int(os.path.basename(qkv_path).replace("qkv_", "").replace(".pt", ""))
        results.append((seq_len, passed))
        if not passed:
            all_passed = False
    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for seq_len, passed in results:
        status = "PASSED" if passed else "FAILED"
        chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
        print(f" seq_len={seq_len} ({chunks} chunk{'s' if chunks > 1 else ''}): {status}")
    print("=" * 60)
    # Exit code signals overall pass/fail to CI
    if all_passed:
        print("test_xattn_chunked: PASSED")
        sys.exit(0)
    else:
        print("test_xattn_chunked: FAILED")
        sys.exit(1)

View File

@@ -6,9 +6,10 @@ Test: XAttention Triton kernels
2. softmax_fuse_block_sum: 对 attention scores 做 softmax 后按 block 求和
数据流:
Q, K [batch, heads, seq_len, head_dim]
Q [batch, heads, q_len, head_dim]
K [batch, heads, kv_len, head_dim]
↓ flat_group_gemm_fuse_reshape
attn_scores [batch, heads, seq_len/stride, seq_len/stride]
attn_scores [batch, heads, q_len/stride, kv_len/stride]
↓ softmax_fuse_block_sum
block_sums [batch, heads, q_blocks, k_blocks]
"""
@@ -21,7 +22,11 @@ from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_
# 参数配置
# ============================================================
seq_len = 512 # Triton 要求 seq_len >= stride * BLOCK_M = 4 * 128 = 512
# Triton 约束: q_len >= stride * BLOCK_M, kv_len >= stride * BLOCK_N
# A100: BLOCK_M = BLOCK_N = 128, 所以 min = 4 * 128 = 512
# RTX 3090: BLOCK_M = BLOCK_N = 64, 所以 min = 4 * 64 = 256
q_len = 512
kv_len = 2048
head_dim = 128
stride = 4
block_size = 128 # softmax block size (in reshaped space)
@@ -31,26 +36,56 @@ segment_size = 128 # Triton kernel 要求 segment_size >= block_size
# 构造输入: 偶数位置=1, 奇数位置=2
# ============================================================
Q = torch.zeros(1, 1, seq_len, head_dim, dtype=torch.bfloat16).cuda()
K = torch.zeros(1, 1, seq_len, head_dim, dtype=torch.bfloat16).cuda()
for i in range(seq_len):
Q = torch.zeros(1, 1, q_len, head_dim, dtype=torch.bfloat16).cuda()
K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()
for i in range(q_len):
if i % 2 == 0:
Q[0, 0, i, :] = 1
K[0, 0, i, :] = 1
else:
Q[0, 0, i, :] = 2
for i in range(kv_len):
if i % 2 == 0:
K[0, 0, i, :] = 1
else:
K[0, 0, i, :] = 2
# ============================================================
# Step 1: flat_group_gemm_fuse_reshape
# Step 1: flat_group_gemm_fuse_reshape (chunked along K)
# ============================================================
attn_scores = flat_group_gemm_fuse_reshape(
Q, K, stride,
chunk_start=0,
chunk_end=seq_len // stride,
is_causal=False
)
q_reshaped_len = q_len // stride # 128
kv_reshaped_len = kv_len // stride # 512
# 将 K 沿着长度维度分成多个 chunk
k_chunk_size = 512 # 每个 chunk 512 tokens
num_k_chunks = kv_len // k_chunk_size # 4 chunks
attn_scores_list = []
for k_chunk_idx in range(num_k_chunks):
k_start = k_chunk_idx * k_chunk_size
k_end = k_start + k_chunk_size
K_chunk = K[:, :, k_start:k_end, :] # [1, 1, k_chunk_size, head_dim]
# 对每个 K chunk 调用 flat_group_gemm_fuse_reshape
# 输出: [batch, heads, q_len/stride, k_chunk_size/stride]
attn_chunk = flat_group_gemm_fuse_reshape(
Q, K_chunk, stride,
chunk_start=0,
chunk_end=q_reshaped_len,
is_causal=False
)
attn_scores_list.append(attn_chunk)
# 拼接所有 K chunks 的结果
# 每个 chunk: [1, 1, q_reshaped_len, k_chunk_size/stride]
# 拼接后: [1, 1, q_reshaped_len, kv_reshaped_len]
attn_scores = torch.cat(attn_scores_list, dim=-1)
# 验证 shape: [batch, heads, q_len/stride, kv_len/stride]
assert attn_scores.shape == (1, 1, q_reshaped_len, kv_reshaped_len), \
f"shape mismatch: {attn_scores.shape} != (1, 1, {q_reshaped_len}, {kv_reshaped_len})"
# 验证: 反对角线求和
# 每个 stride x stride 块的反对角线: Q[奇]*K[偶] + Q[偶]*K[奇] = 2*1 + 1*2 = 4
@@ -63,7 +98,6 @@ assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expect
# Step 2: softmax_fuse_block_sum
# ============================================================
reshaped_len = seq_len // stride
scale = 1.4426950408889634 # log2(e) for exp2
block_sums = softmax_fuse_block_sum(
@@ -71,15 +105,24 @@ block_sums = softmax_fuse_block_sum(
block_size,
segment_size,
chunk_start=0,
chunk_end=reshaped_len,
real_q_len=reshaped_len,
chunk_end=q_reshaped_len,
real_q_len=q_reshaped_len,
scale=scale,
is_causal=False
)
# 验证 shape: [batch, heads, q_blocks, k_blocks]
q_blocks = q_reshaped_len // block_size # 128 / 128 = 1
k_blocks = kv_reshaped_len // block_size # 512 / 128 = 4
assert block_sums.shape == (1, 1, q_blocks, k_blocks), \
f"shape mismatch: {block_sums.shape} != (1, 1, {q_blocks}, {k_blocks})"
# 验证: 每个 block 的 softmax 结果求和
# 所有 attn_scores 相同 → softmax 均匀分布 → block_sum = block_size^2 / reshaped_len
expected_sum = block_size * block_size / reshaped_len
# 所有 attn_scores 相同 → softmax 均匀分布
# 每行对一个 K block 的贡献 = block_size / kv_reshaped_len
# 每个 Q block 有 block_size 行
# block_sum = block_size * (block_size / kv_reshaped_len)
expected_sum = block_size * block_size / kv_reshaped_len
actual_sum = block_sums[0, 0, 0, 0].item()
assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}"