[WIP] Before refactor the compute)_chunked_prefill.
This commit is contained in:
@@ -48,7 +48,7 @@ class Config:
|
||||
# XAttention BSA specific parameters
|
||||
sparse_block_size: int = 128 # Block size for BSA (tokens per block)
|
||||
sparse_samples_per_chunk: int = 128 # Samples per chunk for estimation
|
||||
sparse_threshold: float = 0.9 # Cumulative attention threshold (0-1)
|
||||
sparse_threshold: float = 0.95 # Cumulative attention threshold (tau in XAttention)
|
||||
sparse_use_triton: bool = True # Use Triton kernels for estimation
|
||||
sparse_stride: int = 8 # Stride for Q/K downsampling
|
||||
|
||||
|
||||
@@ -124,42 +124,6 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
"""
|
||||
return available_blocks
|
||||
|
||||
def _load_all_historical_kv(
|
||||
self,
|
||||
cpu_block_table: List[int],
|
||||
layer_id: int,
|
||||
offload_engine: "OffloadEngine",
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Load all historical K/V from CPU to GPU.
|
||||
|
||||
Args:
|
||||
cpu_block_table: List of CPU block IDs
|
||||
layer_id: Current layer index
|
||||
offload_engine: OffloadEngine instance
|
||||
|
||||
Returns:
|
||||
(k_hist, v_hist) with shape [total_tokens, kv_heads, head_dim]
|
||||
"""
|
||||
if not cpu_block_table:
|
||||
return None, None
|
||||
|
||||
k_list = []
|
||||
v_list = []
|
||||
|
||||
for cpu_block_id in cpu_block_table:
|
||||
k_block, v_block = offload_engine.load_block_full_from_cpu(
|
||||
cpu_block_id, layer_id
|
||||
)
|
||||
k_list.append(k_block)
|
||||
v_list.append(v_block)
|
||||
|
||||
# Concatenate: [num_blocks, block_size, kv_heads, head_dim] -> [total_tokens, kv_heads, head_dim]
|
||||
k_hist = torch.cat(k_list, dim=0)
|
||||
v_hist = torch.cat(v_list, dim=0)
|
||||
|
||||
return k_hist, v_hist
|
||||
|
||||
def compute_chunked_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
|
||||
@@ -258,8 +258,12 @@ class Attention(nn.Module):
|
||||
raise RuntimeError("sparse_policy is required for chunked decode")
|
||||
|
||||
# Check if policy supports decode phase
|
||||
# If not, fallback to FullAttentionPolicy (e.g., XAttentionBSAPolicy only supports prefill)
|
||||
if not sparse_policy.supports_decode:
|
||||
raise RuntimeError(f"{sparse_policy} does not support decode phase")
|
||||
from nanovllm.kvcache.sparse import FullAttentionPolicy
|
||||
sparse_policy = FullAttentionPolicy()
|
||||
logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, "
|
||||
f"falling back to FullAttentionPolicy")
|
||||
|
||||
# [DEBUG] Verify execution path
|
||||
logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
|
||||
|
||||
310
task_plan.md
310
task_plan.md
@@ -1,90 +1,286 @@
|
||||
# Task Plan: XAttention BSA 集成到 nanovllm
|
||||
# Task Plan: XAttention BSA 真正的 Sparse 实现
|
||||
|
||||
## Goal
|
||||
|
||||
使用 `--sparse-policy XATTN_BSA` 运行 `test_ruler.py`,通过 `niah_single_1` 的前 5 个 sample。
|
||||
实现 XAttentionBSAPolicy 的真正 sparse attention,在 `select_blocks` 中使用 `xattn_estimate_chunked` 选择重要的 blocks,然后复用 FullAttentionPolicy 的 ring buffer pipeline。
|
||||
|
||||
**验收标准**:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=X PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_ruler.py \
|
||||
--model ~/models/Llama-3.1-8B-Instruct \
|
||||
--enable-offload \
|
||||
--sparse-policy XATTN_BSA \
|
||||
--task niah_single_1 \
|
||||
--sample-ids 0,1,2,3,4
|
||||
# 期望: 5/5 PASS
|
||||
--datasets niah_single_1 \
|
||||
--sample-indices 0,1,2,3,4
|
||||
# 期望: 5/5 PASS,并且真正使用 sparse selection
|
||||
```
|
||||
|
||||
## 当前状态
|
||||
## 当前状态: Phase 1 - 代码分析完成
|
||||
|
||||
- `XAttentionBSAPolicy.compute_chunked_prefill` 实现 = `FullAttentionPolicy`(无 sparse)
|
||||
- `xattn_estimate_chunked` 已实现并验证
|
||||
- BSA kernel (`block_sparse_attn`) 可用
|
||||
## 核心设计理解
|
||||
|
||||
### 1. Block Size 关系
|
||||
|
||||
| 参数 | 值 | 说明 |
|
||||
|------|-----|------|
|
||||
| BSA block_size | 128 tokens | XAttention 的 block 粒度 |
|
||||
| kvcache_block_size | 1024 tokens | CPU offload 的 block 粒度 |
|
||||
| 比例 | 1:8 | 1 CPU block = 8 BSA blocks |
|
||||
|
||||
### 2. 特化条件(用户要求)
|
||||
|
||||
- BSA chunk_size = 外部 chunk_size
|
||||
- 这样 `xattn_estimate_chunked` 返回的 mask 可以直接映射到 CPU block selection
|
||||
- 复用现有的 `flash_attn_with_lse` + `merge_attention_outputs`
|
||||
|
||||
### 3. select_blocks 设计
|
||||
|
||||
```
|
||||
select_blocks(available_blocks, offload_engine, ctx) -> List[int]
|
||||
│
|
||||
├─ 1. 从 metadata cache 获取下采样的 K
|
||||
│ (在 on_prefill_offload 中收集)
|
||||
│
|
||||
├─ 2. 调用 xattn_estimate_chunked(Q, K_downsampled, q_start_pos)
|
||||
│ 返回 mask: [B, H, q_blocks, k_blocks]
|
||||
│
|
||||
├─ 3. 将 BSA k_blocks 映射到 CPU block IDs
|
||||
│ 每 8 个 BSA blocks = 1 CPU block
|
||||
│ 只要 8 个中有任意一个被选中,就保留该 CPU block
|
||||
│
|
||||
└─ 4. 返回 selected_cpu_blocks
|
||||
```
|
||||
|
||||
### 4. Metadata 存储策略
|
||||
|
||||
**方案 A**: 存储下采样的 K(内存友好)
|
||||
```python
|
||||
# on_prefill_offload 中:
|
||||
k_downsampled = k_cache[::stride] # [block_size/stride, H, D]
|
||||
self._k_cache[layer_id][cpu_block_id] = k_downsampled
|
||||
```
|
||||
|
||||
**内存计算** (stride=8):
|
||||
- 每 block: (1024/8) * 8 * 128 * 2 bytes = 256 KB
|
||||
- 256 blocks * 32 layers = 2 GB (GPU 上用于快速估计)
|
||||
|
||||
**方案 B**: 存储 min/max metadata (更省内存)
|
||||
```python
|
||||
# on_prefill_offload 中:
|
||||
k_min = k_cache[:num_valid].min(dim=0).values # [H, D]
|
||||
k_max = k_cache[:num_valid].max(dim=0).values # [H, D]
|
||||
```
|
||||
- 但这需要不同的估计算法,不能直接用 xattn_estimate
|
||||
|
||||
**决定**: 使用方案 A(下采样 K),因为可以直接复用 xattn_estimate_chunked
|
||||
|
||||
## Phases
|
||||
|
||||
- [ ] Phase 1: 理解当前代码路径
|
||||
- [ ] Phase 2: 实现 sparse mask 估计
|
||||
- [ ] Phase 3: 实现 BSA sparse 计算
|
||||
- [ ] Phase 4: 测试验证
|
||||
- [x] Phase 1: 代码分析,理解当前实现
|
||||
- [ ] Phase 2: 实现 on_prefill_offload 收集 K metadata
|
||||
- [ ] Phase 3: 实现 select_blocks 中的 xattn estimation
|
||||
- [ ] Phase 4: 实现 BSA block → CPU block 的映射
|
||||
- [ ] Phase 5: 测试验证
|
||||
|
||||
## Phase 1: 理解当前代码路径
|
||||
## Phase 2: on_prefill_offload 实现
|
||||
|
||||
### 1.1 确认 XATTN_BSA policy 是否被正确加载
|
||||
- [ ] 检查 `test_ruler.py` 如何解析 `--sparse-policy XATTN_BSA`
|
||||
- [ ] 检查 `KVCacheManager` 如何实例化 sparse_policy
|
||||
- [ ] 运行 baseline 测试(`--sparse-policy FULL`)确认基础功能正常
|
||||
### 需要修改的文件
|
||||
- `nanovllm/kvcache/sparse/xattn_bsa.py`
|
||||
|
||||
### 1.2 确认数据流
|
||||
- [ ] `compute_chunked_prefill` 的输入参数含义
|
||||
- [ ] `offload_engine` 提供的数据访问接口
|
||||
- [ ] 当前 chunk 的 K/V 如何获取
|
||||
### 实现细节
|
||||
|
||||
## Phase 2: 实现 sparse mask 估计
|
||||
```python
|
||||
class XAttentionBSAPolicy(SparsePolicy):
|
||||
def __init__(self, threshold=0.9, stride=8, ...):
|
||||
self.threshold = threshold
|
||||
self.stride = stride
|
||||
self._k_cache: Dict[int, Dict[int, torch.Tensor]] = {}
|
||||
# _k_cache[layer_id][cpu_block_id] = k_downsampled
|
||||
|
||||
### 2.1 调用 xattn_estimate_chunked
|
||||
- [ ] 在 `compute_chunked_prefill` 中加载历史 K
|
||||
- [ ] 拼接历史 K + 当前 K
|
||||
- [ ] 调用 `xattn_estimate_chunked(q, k_full, q_start_pos=...)`
|
||||
- [ ] 获取 block mask
|
||||
def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device):
|
||||
"""初始化 K cache 结构"""
|
||||
self._k_cache = {layer_id: {} for layer_id in range(num_layers)}
|
||||
self._num_kv_heads = num_kv_heads
|
||||
self._head_dim = head_dim
|
||||
|
||||
### 2.2 处理参数对齐
|
||||
- [ ] BSA block_size = 128
|
||||
- [ ] chunk_size 与 kvcache_block_size 的关系
|
||||
- [ ] q_start_pos 计算
|
||||
def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
|
||||
"""收集下采样的 K 用于后续估计"""
|
||||
# k_cache: [block_size, num_kv_heads, head_dim]
|
||||
k_downsampled = k_cache[:num_valid_tokens:self.stride].clone()
|
||||
# k_downsampled: [num_valid_tokens//stride, num_kv_heads, head_dim]
|
||||
self._k_cache[layer_id][cpu_block_id] = k_downsampled
|
||||
```
|
||||
|
||||
## Phase 3: 实现 BSA sparse 计算
|
||||
## Phase 3: select_blocks 实现
|
||||
|
||||
### 3.1 方案选择
|
||||
- 选项 A: 历史 + 当前分开计算,然后 merge
|
||||
- 选项 B: 全部一起用 BSA 计算
|
||||
### 关键问题
|
||||
|
||||
### 3.2 实现
|
||||
- [ ] 构造 BSA 需要的输入格式
|
||||
- [ ] 调用 `block_sparse_attn_func`
|
||||
- [ ] 处理输出格式
|
||||
1. **Q 从哪里来?**
|
||||
- `ctx.query` 需要在调用 select_blocks 时传入
|
||||
- 当前 FullAttentionPolicy 传递 `query=None`
|
||||
- 需要修改 compute_chunked_prefill 传递真实的 Q
|
||||
|
||||
## Phase 4: 测试验证
|
||||
2. **Q 的格式转换**
|
||||
- 输入 Q: [seq_len, num_heads, head_dim]
|
||||
- xattn 需要: [B, H, q_len, D]
|
||||
- 转换: `q.unsqueeze(0).transpose(1, 2)`
|
||||
|
||||
### 4.1 单元测试
|
||||
- [ ] 验证 sparse mask 与 `test_xattn_estimate_chunked.py` 一致
|
||||
3. **K 的组装**
|
||||
- 从 `_k_cache[layer_id]` 获取各 block 的下采样 K
|
||||
- 按 `available_blocks` 顺序 cat 起来
|
||||
- 结果: [B, H, total_k_downsampled, D]
|
||||
|
||||
### 4.2 集成测试
|
||||
- [ ] 运行验收命令
|
||||
- [ ] 5/5 PASS
|
||||
### 实现草案
|
||||
|
||||
## Key Questions
|
||||
```python
|
||||
def select_blocks(self, available_blocks, offload_engine, ctx):
|
||||
if not available_blocks or ctx.query is None:
|
||||
return available_blocks
|
||||
|
||||
1. 历史 K 如何高效加载?(全量 vs 按需)
|
||||
2. BSA causal mask 如何处理?(历史 non-causal + 当前 causal)
|
||||
layer_id = ctx.layer_id
|
||||
|
||||
# 1. 组装下采样的 K
|
||||
k_list = []
|
||||
for cpu_block_id in available_blocks:
|
||||
if cpu_block_id in self._k_cache[layer_id]:
|
||||
k_list.append(self._k_cache[layer_id][cpu_block_id])
|
||||
|
||||
if not k_list:
|
||||
return available_blocks
|
||||
|
||||
k_hist = torch.cat(k_list, dim=0) # [total_tokens/stride, H, D]
|
||||
k_hist = k_hist.unsqueeze(0).transpose(1, 2) # [1, H, k_len, D]
|
||||
|
||||
# 2. 准备 Q
|
||||
q = ctx.query # [seq_len, num_heads, head_dim]
|
||||
q = q.unsqueeze(0).transpose(1, 2) # [1, H, q_len, D]
|
||||
|
||||
# GQA 扩展(如果需要)
|
||||
if q.shape[1] != k_hist.shape[1]:
|
||||
num_groups = q.shape[1] // k_hist.shape[1]
|
||||
k_hist = k_hist.repeat_interleave(num_groups, dim=1)
|
||||
|
||||
# 3. 计算 q_start_pos
|
||||
q_start_pos = len(available_blocks) * ctx.block_size
|
||||
|
||||
# 4. 调用 xattn_estimate_chunked
|
||||
# 注意:K 已经是下采样的,需要调整参数
|
||||
attn_sum, mask = xattn_estimate_chunked(
|
||||
q, k_hist,
|
||||
q_start_pos=q_start_pos // self.stride, # 调整到下采样空间
|
||||
block_size=self.BSA_BLOCK_SIZE // self.stride, # 16
|
||||
stride=1, # K 已经下采样
|
||||
threshold=self.threshold,
|
||||
chunk_size=q.shape[2], # 与 Q 长度一致
|
||||
use_triton=self.use_triton,
|
||||
)
|
||||
|
||||
# 5. 从 mask 提取 CPU block IDs
|
||||
# mask: [1, H, q_blocks, k_blocks]
|
||||
# 对所有 heads 取 OR
|
||||
selected_mask = mask.any(dim=1).squeeze(0) # [q_blocks, k_blocks]
|
||||
# 对所有 q_blocks 取 OR(只要任意 Q 位置需要这个 K block)
|
||||
selected_k_mask = selected_mask.any(dim=0) # [k_blocks]
|
||||
|
||||
# 6. 映射 BSA blocks → CPU blocks
|
||||
# 每个 CPU block = 8 BSA blocks (block_size=1024, BSA_block=128)
|
||||
bsa_to_cpu_ratio = ctx.block_size // self.BSA_BLOCK_SIZE # 8
|
||||
num_cpu_blocks = len(available_blocks)
|
||||
|
||||
selected_cpu_indices = set()
|
||||
for bsa_idx in selected_k_mask.nonzero(as_tuple=True)[0].tolist():
|
||||
cpu_idx = bsa_idx // bsa_to_cpu_ratio
|
||||
if cpu_idx < num_cpu_blocks:
|
||||
selected_cpu_indices.add(cpu_idx)
|
||||
|
||||
selected_blocks = [available_blocks[i] for i in sorted(selected_cpu_indices)]
|
||||
|
||||
logger.info(f"[XAttn] select_blocks: {len(available_blocks)} -> {len(selected_blocks)} "
|
||||
f"({100*len(selected_blocks)/len(available_blocks):.1f}%)")
|
||||
|
||||
return selected_blocks
|
||||
```
|
||||
|
||||
## Phase 4: compute_chunked_prefill
|
||||
|
||||
### 关键修改
|
||||
|
||||
1. **传递真实的 Q 给 select_blocks**
|
||||
- 修改 PolicyContext 构造,设置 `query=q`
|
||||
|
||||
2. **复用 FullAttentionPolicy 的 pipeline**
|
||||
- 继承 FullAttentionPolicy 而不是 SparsePolicy
|
||||
- 或者直接调用父类方法
|
||||
|
||||
### 方案对比
|
||||
|
||||
**方案 A**: XAttentionBSAPolicy 继承 FullAttentionPolicy
|
||||
```python
|
||||
class XAttentionBSAPolicy(FullAttentionPolicy):
|
||||
# 只需要 override select_blocks 和 on_prefill_offload
|
||||
# compute_chunked_prefill 直接用父类的
|
||||
```
|
||||
|
||||
**方案 B**: 独立实现,调用相同的 pipeline 代码
|
||||
```python
|
||||
class XAttentionBSAPolicy(SparsePolicy):
|
||||
def compute_chunked_prefill(self, q, k, v, ...):
|
||||
# 复制 FullAttentionPolicy 的代码
|
||||
# 但修改 PolicyContext 传递 query=q
|
||||
```
|
||||
|
||||
**决定**: 使用方案 B,因为需要在 compute_chunked_prefill 中修改 PolicyContext
|
||||
|
||||
## Phase 5: 测试
|
||||
|
||||
### 单元测试
|
||||
```bash
|
||||
# 测试 select_blocks 的 sparsity
|
||||
python -c "
|
||||
from nanovllm.kvcache.sparse.xattn_bsa import XAttentionBSAPolicy
|
||||
policy = XAttentionBSAPolicy(threshold=0.9)
|
||||
# ... 测试代码
|
||||
"
|
||||
```
|
||||
|
||||
### 集成测试
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_ruler.py \
|
||||
--model ~/models/Llama-3.1-8B-Instruct \
|
||||
--enable-offload \
|
||||
--sparse-policy XATTN_BSA \
|
||||
--datasets niah_single_1 \
|
||||
--sample-indices 0,1,2,3,4
|
||||
```
|
||||
|
||||
## Key Decisions
|
||||
|
||||
| 决策 | 理由 |
|
||||
|------|------|
|
||||
| 使用下采样 K 作为 metadata | 可以直接复用 xattn_estimate_chunked |
|
||||
| stride=8 | 平衡内存和精度 |
|
||||
| BSA blocks → CPU blocks 映射用 OR | 只要有一个 BSA block 被选中就保留 |
|
||||
| 继承 FullAttentionPolicy 的 pipeline | 复用已验证的 ring buffer 流程 |
|
||||
|
||||
## Files to Modify
|
||||
|
||||
| 文件 | 修改 |
|
||||
|------|------|
|
||||
| `nanovllm/kvcache/sparse/xattn_bsa.py` | 主要实现:initialize, on_prefill_offload, select_blocks |
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **GQA 处理**: Llama-3.1-8B 有 32 query heads, 8 kv heads,需要在估计时扩展 K
|
||||
2. **内存管理**: `_k_cache` 存储在 GPU,需要在 reset() 时清理
|
||||
3. **Triton 兼容性**: xattn_estimate_chunked 有 Triton bug,可能需要用 PyTorch fallback
|
||||
4. **边界条件**: 第一个 chunk (available_blocks=[]) 时直接返回空列表
|
||||
|
||||
## Errors Encountered
|
||||
|
||||
(待填充)
|
||||
|
||||
## Status
|
||||
|
||||
**Currently in Phase 1** - 等待用户确认后开始
|
||||
|
||||
## 待讨论
|
||||
|
||||
请确认:
|
||||
1. 这个 goal 和验收标准是否正确?
|
||||
2. 我使用哪个 GPU 运行测试?
|
||||
**Currently in Phase 1** - 代码分析完成,准备开始 Phase 2 实现
|
||||
|
||||
334
tests/test_xattn_bsa.py
Normal file
334
tests/test_xattn_bsa.py
Normal file
@@ -0,0 +1,334 @@
|
||||
"""
|
||||
Test XAttention + BSA with RULER benchmark data.
|
||||
|
||||
Tests XAttention sparse attention correctness using RULER NIAH task.
|
||||
|
||||
Attention methods:
|
||||
- Prefill: XAttention + BSA (sparse) or FlashAttention (dense)
|
||||
- Decode: FlashAttention (always, since q_len=1)
|
||||
|
||||
Usage (in compass conda env with BSA available):
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct
|
||||
|
||||
# Test with XAttention + BSA for prefill (default)
|
||||
python tests/test_xattn_bsa.py --prefill-method xattn
|
||||
|
||||
# Test with FlashAttention for prefill (baseline)
|
||||
python tests/test_xattn_bsa.py --prefill-method flash
|
||||
|
||||
# Test specific sample(s)
|
||||
python tests/test_xattn_bsa.py --sample-id 0
|
||||
python tests/test_xattn_bsa.py --sample-ids 0,1,2
|
||||
|
||||
Note: Compatible with transformers 4.53+ (handles both old `past_key_value`
|
||||
and new `past_key_values` API).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.cache_utils import DynamicCache
|
||||
|
||||
from nanovllm.ops.xattn import xattn_estimate
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
|
||||
# ============================================================
|
||||
# XAttention + BSA Functions
|
||||
# ============================================================
|
||||
|
||||
def expand_kv_for_gqa(key_states, value_states, num_heads):
|
||||
"""Expand KV for Grouped Query Attention."""
|
||||
num_kv_heads = key_states.shape[1]
|
||||
if num_heads == num_kv_heads:
|
||||
return key_states, value_states
|
||||
num_groups = num_heads // num_kv_heads
|
||||
return key_states.repeat_interleave(num_groups, dim=1), value_states.repeat_interleave(num_groups, dim=1)
|
||||
|
||||
|
||||
def flash_attention_forward(query_states, key_states, value_states, is_causal=True):
|
||||
"""Standard FlashAttention."""
|
||||
from flash_attn import flash_attn_func
|
||||
q = query_states.transpose(1, 2)
|
||||
k = key_states.transpose(1, 2)
|
||||
v = value_states.transpose(1, 2)
|
||||
return flash_attn_func(q, k, v, causal=is_causal).transpose(1, 2)
|
||||
|
||||
|
||||
def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9):
|
||||
"""XAttention + BSA sparse attention."""
|
||||
from block_sparse_attn import block_sparse_attn_func
|
||||
|
||||
batch_size, num_heads, q_len, head_dim = query_states.shape
|
||||
k_len = key_states.shape[2]
|
||||
|
||||
_, mask = xattn_estimate(
|
||||
query_states, key_states,
|
||||
chunk_size=16384, block_size=128, threshold=threshold,
|
||||
use_triton=True, causal=True,
|
||||
)
|
||||
|
||||
q_block_num = (q_len + 127) // 128
|
||||
k_block_num = (k_len + 127) // 128
|
||||
|
||||
q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
|
||||
k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
|
||||
v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
|
||||
|
||||
__import__('pdb').set_trace()
|
||||
|
||||
output = block_sparse_attn_func(
|
||||
q, k, v,
|
||||
torch.tensor([0, q_len], dtype=torch.int32, device=q.device),
|
||||
torch.tensor([0, k_len], dtype=torch.int32, device=k.device),
|
||||
torch.ones(num_heads, dtype=torch.int32, device=q.device),
|
||||
None,
|
||||
mask[:, :, :q_block_num, :k_block_num].contiguous(),
|
||||
q_len, k_len,
|
||||
p_dropout=0.0, deterministic=True, is_causal=True,
|
||||
)
|
||||
return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
|
||||
|
||||
|
||||
DEBUG = False # Set to True to enable debugging
|
||||
|
||||
def create_patched_forward(prefill_method="xattn", threshold=0.9):
|
||||
"""Create patched forward with configurable prefill method.
|
||||
|
||||
Args:
|
||||
prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense)
|
||||
threshold: XAttention threshold for block selection (only used when prefill_method="xattn")
|
||||
|
||||
Note:
|
||||
- Prefill (q_len > 1): Uses specified prefill_method
|
||||
- Decode (q_len = 1): Always uses FlashAttention (no sparse needed for single query)
|
||||
"""
|
||||
call_count = [0] # Mutable to track calls across layers
|
||||
|
||||
def patched_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
position_embeddings=None,
|
||||
attention_mask=None,
|
||||
past_key_value=None, # Old API (transformers < 4.57)
|
||||
past_key_values=None, # New API (transformers >= 4.57)
|
||||
cache_position=None,
|
||||
**kwargs
|
||||
):
|
||||
# Handle both old and new transformers API
|
||||
kv_cache = past_key_values if past_key_values is not None else past_key_value
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
num_heads = self.config.num_attention_heads
|
||||
num_kv_heads = self.config.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
|
||||
# Compute Q, K, V projections
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Apply rotary position embedding
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
# Handle KV cache
|
||||
if kv_cache is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = kv_cache.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# Expand KV for GQA
|
||||
key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads)
|
||||
|
||||
# Debug output
|
||||
if DEBUG and self.layer_idx == 0:
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= 5:
|
||||
phase = "prefill" if q_len > 1 else "decode"
|
||||
print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}")
|
||||
print(f" kv_cache is None: {kv_cache is None}")
|
||||
|
||||
# Choose attention method:
|
||||
# - Prefill (q_len > 1): Use prefill_method (xattn or flash)
|
||||
# - Decode (q_len = 1): Always use FlashAttention
|
||||
is_prefill = q_len > 1
|
||||
|
||||
if is_prefill and prefill_method == "xattn":
|
||||
# Prefill with XAttention + BSA (sparse)
|
||||
attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold)
|
||||
else:
|
||||
# Prefill with FlashAttention (dense) OR Decode (always FlashAttention)
|
||||
# Note: For decode (q_len=1), causal=False since single query attends to all KV
|
||||
attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill)
|
||||
|
||||
attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1))
|
||||
return attn_output, None
|
||||
|
||||
return patched_forward
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Data & Evaluation
|
||||
# ============================================================
|
||||
|
||||
def load_samples(filepath, indices=None):
|
||||
"""Load samples from JSONL file."""
|
||||
samples = []
|
||||
with open(filepath) as f:
|
||||
for i, line in enumerate(f):
|
||||
if indices is None or i in indices:
|
||||
sample = json.loads(line)
|
||||
sample["_idx"] = i
|
||||
samples.append(sample)
|
||||
return samples
|
||||
|
||||
|
||||
def string_match_all(output_text, expected_list):
|
||||
"""RULER metric: fraction of expected values found in output."""
|
||||
output_lower = output_text.lower().replace('\n', ' ')
|
||||
if not expected_list:
|
||||
return 1.0
|
||||
return sum(1.0 if exp.strip().lower() in output_lower else 0.0 for exp in expected_list) / len(expected_list)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Test
|
||||
# ============================================================
|
||||
|
||||
def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50):
|
||||
"""Test attention methods using RULER data.
|
||||
|
||||
Args:
|
||||
prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention
|
||||
"""
|
||||
prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)"
|
||||
|
||||
print("=" * 60)
|
||||
print("RULER NIAH Attention Test")
|
||||
print("=" * 60)
|
||||
print(f"Data: {data_file}")
|
||||
print(f"Samples: {sample_ids}")
|
||||
print(f"Prefill method: {prefill_desc}")
|
||||
print(f"Decode method: FlashAttention (always)")
|
||||
if prefill_method == "xattn":
|
||||
print(f"XAttention threshold: {threshold}")
|
||||
|
||||
samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None)
|
||||
if not samples:
|
||||
print("No samples found!")
|
||||
return False
|
||||
print(f"Loaded {len(samples)} samples")
|
||||
|
||||
# Load model
|
||||
print(f"\nLoading model: {model_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, torch_dtype=torch.float16, device_map="cuda",
|
||||
attn_implementation="eager", # Will be patched
|
||||
)
|
||||
model.eval()
|
||||
|
||||
# Patch all layers
|
||||
print(f"Patching attention layers...")
|
||||
print(f" - Prefill: {prefill_desc}")
|
||||
print(f" - Decode: FlashAttention")
|
||||
for idx, layer in enumerate(model.model.layers):
|
||||
layer.self_attn.layer_idx = idx # Ensure layer_idx is set
|
||||
layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__(
|
||||
layer.self_attn, type(layer.self_attn)
|
||||
)
|
||||
|
||||
total_score = 0.0
|
||||
results = []
|
||||
|
||||
for sample in samples:
|
||||
idx = sample["_idx"]
|
||||
prompt = sample["input"]
|
||||
expected = sample["outputs"]
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||
num_tokens = inputs["input_ids"].shape[1]
|
||||
print(f"\n--- Sample {idx} ({num_tokens} tokens) ---")
|
||||
print(f"Expected: {expected}")
|
||||
|
||||
with torch.no_grad():
|
||||
output = model.generate(
|
||||
inputs["input_ids"],
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True)
|
||||
score = string_match_all(output_text, expected)
|
||||
total_score += score
|
||||
|
||||
status = "✓ PASS" if score >= 0.5 else "✗ FAIL"
|
||||
print(f"Output: '{output_text[:100]}...'")
|
||||
print(f"Result: {status} (score={score:.2f})")
|
||||
results.append({"idx": idx, "score": score, "passed": score >= 0.5})
|
||||
|
||||
avg_score = total_score / len(samples)
|
||||
passed = sum(1 for r in results if r["passed"])
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
return avg_score >= 0.5
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark"
|
||||
)
|
||||
parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct")
|
||||
parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl")
|
||||
parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index")
|
||||
parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)")
|
||||
parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn",
|
||||
help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)")
|
||||
parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)")
|
||||
parser.add_argument("--max-new-tokens", type=int, default=50)
|
||||
# Keep old option for backwards compatibility
|
||||
parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = args.model.replace("~", "/home/zijie")
|
||||
|
||||
# Handle deprecated --no-xattn option
|
||||
prefill_method = args.prefill_method
|
||||
if args.no_xattn:
|
||||
prefill_method = "flash"
|
||||
print("Warning: --no-xattn is deprecated, use --prefill-method flash instead")
|
||||
|
||||
if args.sample_id is not None:
|
||||
sample_ids = [args.sample_id]
|
||||
elif args.sample_ids:
|
||||
sample_ids = [int(x) for x in args.sample_ids.split(",")]
|
||||
else:
|
||||
sample_ids = [0]
|
||||
|
||||
# Check BSA availability if using xattn
|
||||
if prefill_method == "xattn":
|
||||
try:
|
||||
from block_sparse_attn import block_sparse_attn_func
|
||||
print("✓ BSA (Block Sparse Attention) available")
|
||||
except ImportError:
|
||||
print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash")
|
||||
sys.exit(1)
|
||||
|
||||
if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens):
|
||||
print("\ntest_xattn_bsa: PASSED")
|
||||
else:
|
||||
print("\ntest_xattn_bsa: FAILED")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
259
tests/test_xattn_chunked.py
Normal file
259
tests/test_xattn_chunked.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Test: Compare xattn_estimate vs xattn_estimate_chunked
|
||||
Verify that chunked estimation with EXTERNAL chunking produces the same mask as standard estimation.
|
||||
|
||||
Uses real QKV data captured from model inference.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
|
||||
|
||||
# ============================================================
|
||||
# Configuration
|
||||
# ============================================================
|
||||
|
||||
BLOCK_SIZE = 64
|
||||
STRIDE = 4
|
||||
THRESHOLD = 0.9
|
||||
CHUNK_SIZE = 4096
|
||||
|
||||
# Default QKV data directory (relative to project root)
|
||||
DEFAULT_QKV_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "kvcache")
|
||||
|
||||
# ============================================================
|
||||
# Utility Functions
|
||||
# ============================================================
|
||||
|
||||
def load_qkv(path):
|
||||
"""Load saved QKV data."""
|
||||
data = torch.load(path, map_location="cpu", weights_only=False)
|
||||
print(f"Loaded: {path}")
|
||||
print(f" Query shape: {data['query'].shape}")
|
||||
print(f" Key shape: {data['key'].shape}")
|
||||
print(f" Layer: {data['layer_id']}, Density: {data['density']:.2%}")
|
||||
return data
|
||||
|
||||
|
||||
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
|
||||
"""Compare two masks and report differences."""
|
||||
if mask1.shape != mask2.shape:
|
||||
print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
|
||||
return False
|
||||
|
||||
diff = (mask1 != mask2).sum().item()
|
||||
total = mask1.numel()
|
||||
match_rate = (total - diff) / total * 100
|
||||
|
||||
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
|
||||
|
||||
if diff > 0:
|
||||
diff_indices = torch.where(mask1 != mask2)
|
||||
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
|
||||
|
||||
return diff == 0
|
||||
|
||||
|
||||
def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size):
|
||||
"""
|
||||
Run xattn_estimate_chunked with EXTERNAL chunking.
|
||||
This simulates how chunked prefill should be used in practice.
|
||||
"""
|
||||
batch_size, num_heads, q_len, head_dim = query.shape
|
||||
_, _, k_len, _ = key.shape
|
||||
|
||||
q_block_num = (q_len + block_size - 1) // block_size
|
||||
k_block_num = (k_len + block_size - 1) // block_size
|
||||
|
||||
# If Q fits in one chunk, call directly
|
||||
if q_len <= chunk_size:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
return xattn_estimate_chunked(
|
||||
query, key,
|
||||
q_start_pos=q_start_pos,
|
||||
block_size=block_size,
|
||||
stride=stride,
|
||||
threshold=threshold,
|
||||
use_triton=True,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# External chunking: split Q and call for each chunk
|
||||
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
|
||||
print(f" External chunking: {num_q_chunks} chunks")
|
||||
|
||||
combined_attn_sum = torch.zeros(
|
||||
batch_size, num_heads, q_block_num, k_block_num,
|
||||
dtype=query.dtype, device=query.device
|
||||
)
|
||||
combined_mask = torch.zeros(
|
||||
batch_size, num_heads, q_block_num, k_block_num,
|
||||
dtype=torch.bool, device=query.device
|
||||
)
|
||||
|
||||
q_block_offset = 0
|
||||
for q_chunk_idx in range(num_q_chunks):
|
||||
q_chunk_start = q_chunk_idx * chunk_size
|
||||
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
|
||||
|
||||
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
|
||||
|
||||
# For causal attention, K accumulates up to current Q position
|
||||
k_end = q_start_pos + q_chunk_end
|
||||
k_chunk = key[:, :, :k_end, :]
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
|
||||
q_chunk, k_chunk,
|
||||
q_start_pos=q_start_pos + q_chunk_start,
|
||||
block_size=block_size,
|
||||
stride=stride,
|
||||
threshold=threshold,
|
||||
use_triton=True,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# Place chunk results into combined output
|
||||
chunk_q_blocks = mask_chunk.shape[2]
|
||||
chunk_k_blocks = mask_chunk.shape[3]
|
||||
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
|
||||
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
|
||||
q_block_offset += chunk_q_blocks
|
||||
|
||||
return combined_attn_sum, combined_mask
|
||||
|
||||
|
||||
def test_single_qkv(qkv_path):
    """Test a single QKV file: standard vs externally-chunked estimation.

    Loads a saved Q/K tensor pair, runs the monolithic ``xattn_estimate``
    and the external-chunking driver ``run_chunked_externally`` with the
    same block/stride/threshold parameters, then compares the block masks
    (and attention sums) over the region both versions produce.

    Args:
        qkv_path: Path to a ``.pt`` dump readable by ``load_qkv`` containing
            "query" and "key" tensors; presumably shaped
            [batch, heads, seq_len, head_dim] — TODO confirm against writer.

    Returns:
        bool: True when the masks agree on the comparable region; False on
        any exception or mask mismatch.
    """
    data = load_qkv(qkv_path)
    query = data["query"].cuda().to(torch.bfloat16)
    key = data["key"].cuda().to(torch.bfloat16)

    seq_len = query.shape[2]
    print(f"\nTesting with seq_len={seq_len}")
    print("=" * 60)

    # [1] Reference path: monolithic estimate over the whole sequence.
    print("[1] Running standard xattn_estimate...")
    try:
        attn_sum_std, mask_std = xattn_estimate(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            use_triton=True,
        )
        print(f" mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}")
    except Exception as e:
        print(f" ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False

    # [2] Chunked path: Q is split externally, starting at absolute position 0.
    print("[2] Running chunked xattn_estimate (external chunking)...")
    try:
        attn_sum_chunked, mask_chunked = run_chunked_externally(
            query, key,
            q_start_pos=0,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
        )
        print(f" mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}")
    except Exception as e:
        print(f" ERROR: {e}")
        import traceback
        traceback.print_exc()
        return False

    # [3] Compare only the region both outputs cover: the chunked result may
    # have fewer (q, k) blocks than the standard one, so truncate the latter.
    print("[3] Comparing results...")
    chunked_q_blocks = mask_chunked.shape[2]
    chunked_k_blocks = mask_chunked.shape[3]

    mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]

    # Mask agreement drives the pass/fail verdict.
    masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")

    # Attention-sum diff is informational only; it does not affect the verdict.
    attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
        attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
        print(f" Attn sum max diff: {attn_diff:.6f}")
    else:
        # Fixed: was an f-string with no placeholders (ruff F541).
        print(" Attn sum shape mismatch")

    # Free GPU memory between files so larger sequences can run back-to-back.
    del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
    torch.cuda.empty_cache()

    return masks_match
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Test
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked")
    cli.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR,
                     help="Directory containing QKV files")
    args = cli.parse_args()

    # Saved dumps at increasing sequence lengths: ~4K, ~8K, ~16K, ~32K, ~64K.
    dump_names = (
        "qkv_3688.pt",
        "qkv_7888.pt",
        "qkv_15685.pt",
        "qkv_32485.pt",
        "qkv_64891.pt",
    )
    available_files = [
        path
        for path in (os.path.join(args.qkv_dir, name) for name in dump_names)
        if os.path.exists(path)
    ]

    if not available_files:
        print(f"No QKV file found in {args.qkv_dir}.")
        print("Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt")
        sys.exit(1)

    print(f"Found {len(available_files)} QKV files to test")
    print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})")
    print("Using Triton kernels")

    # Run every available dump, recording (seq_len, verdict) for the summary.
    # The sequence length is encoded in the filename: qkv_<seq_len>.pt.
    results = []
    for path in available_files:
        ok = test_single_qkv(path)
        length = int(os.path.basename(path).replace("qkv_", "").replace(".pt", ""))
        results.append((length, ok))
    all_passed = all(ok for _, ok in results)

    banner = "=" * 60
    print("\n" + banner)
    print("SUMMARY")
    print(banner)
    for length, ok in results:
        status = "PASSED" if ok else "FAILED"
        chunks = (length + CHUNK_SIZE - 1) // CHUNK_SIZE
        print(f" seq_len={length} ({chunks} chunk{'s' if chunks > 1 else ''}): {status}")
    print(banner)

    if all_passed:
        print("test_xattn_chunked: PASSED")
        sys.exit(0)
    else:
        print("test_xattn_chunked: FAILED")
        sys.exit(1)
|
||||
@@ -6,9 +6,10 @@ Test: XAttention Triton kernels
|
||||
2. softmax_fuse_block_sum: 对 attention scores 做 softmax 后按 block 求和
|
||||
|
||||
数据流:
|
||||
Q, K [batch, heads, seq_len, head_dim]
|
||||
Q [batch, heads, q_len, head_dim]
|
||||
K [batch, heads, kv_len, head_dim]
|
||||
↓ flat_group_gemm_fuse_reshape
|
||||
attn_scores [batch, heads, seq_len/stride, seq_len/stride]
|
||||
attn_scores [batch, heads, q_len/stride, kv_len/stride]
|
||||
↓ softmax_fuse_block_sum
|
||||
block_sums [batch, heads, q_blocks, k_blocks]
|
||||
"""
|
||||
@@ -21,7 +22,11 @@ from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_
|
||||
# 参数配置
|
||||
# ============================================================
|
||||
|
||||
seq_len = 512 # Triton 要求 seq_len >= stride * BLOCK_M = 4 * 128 = 512
|
||||
# Triton 约束: q_len >= stride * BLOCK_M, kv_len >= stride * BLOCK_N
|
||||
# A100: BLOCK_M = BLOCK_N = 128, 所以 min = 4 * 128 = 512
|
||||
# RTX 3090: BLOCK_M = BLOCK_N = 64, 所以 min = 4 * 64 = 256
|
||||
q_len = 512
|
||||
kv_len = 2048
|
||||
head_dim = 128
|
||||
stride = 4
|
||||
block_size = 128 # softmax block size (in reshaped space)
|
||||
@@ -31,26 +36,56 @@ segment_size = 128 # Triton kernel 要求 segment_size >= block_size
|
||||
# 构造输入: 偶数位置=1, 奇数位置=2
|
||||
# ============================================================
|
||||
|
||||
Q = torch.zeros(1, 1, seq_len, head_dim, dtype=torch.bfloat16).cuda()
|
||||
K = torch.zeros(1, 1, seq_len, head_dim, dtype=torch.bfloat16).cuda()
|
||||
for i in range(seq_len):
|
||||
Q = torch.zeros(1, 1, q_len, head_dim, dtype=torch.bfloat16).cuda()
|
||||
K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()
|
||||
|
||||
for i in range(q_len):
|
||||
if i % 2 == 0:
|
||||
Q[0, 0, i, :] = 1
|
||||
K[0, 0, i, :] = 1
|
||||
else:
|
||||
Q[0, 0, i, :] = 2
|
||||
|
||||
for i in range(kv_len):
|
||||
if i % 2 == 0:
|
||||
K[0, 0, i, :] = 1
|
||||
else:
|
||||
K[0, 0, i, :] = 2
|
||||
|
||||
# ============================================================
|
||||
# Step 1: flat_group_gemm_fuse_reshape
|
||||
# Step 1: flat_group_gemm_fuse_reshape (chunked along K)
|
||||
# ============================================================
|
||||
|
||||
attn_scores = flat_group_gemm_fuse_reshape(
|
||||
Q, K, stride,
|
||||
chunk_start=0,
|
||||
chunk_end=seq_len // stride,
|
||||
is_causal=False
|
||||
)
|
||||
q_reshaped_len = q_len // stride # 128
|
||||
kv_reshaped_len = kv_len // stride # 512
|
||||
|
||||
# 将 K 沿着长度维度分成多个 chunk
|
||||
k_chunk_size = 512 # 每个 chunk 512 tokens
|
||||
num_k_chunks = kv_len // k_chunk_size # 4 chunks
|
||||
|
||||
attn_scores_list = []
|
||||
for k_chunk_idx in range(num_k_chunks):
|
||||
k_start = k_chunk_idx * k_chunk_size
|
||||
k_end = k_start + k_chunk_size
|
||||
K_chunk = K[:, :, k_start:k_end, :] # [1, 1, k_chunk_size, head_dim]
|
||||
|
||||
# 对每个 K chunk 调用 flat_group_gemm_fuse_reshape
|
||||
# 输出: [batch, heads, q_len/stride, k_chunk_size/stride]
|
||||
attn_chunk = flat_group_gemm_fuse_reshape(
|
||||
Q, K_chunk, stride,
|
||||
chunk_start=0,
|
||||
chunk_end=q_reshaped_len,
|
||||
is_causal=False
|
||||
)
|
||||
attn_scores_list.append(attn_chunk)
|
||||
|
||||
# 拼接所有 K chunks 的结果
|
||||
# 每个 chunk: [1, 1, q_reshaped_len, k_chunk_size/stride]
|
||||
# 拼接后: [1, 1, q_reshaped_len, kv_reshaped_len]
|
||||
attn_scores = torch.cat(attn_scores_list, dim=-1)
|
||||
|
||||
# 验证 shape: [batch, heads, q_len/stride, kv_len/stride]
|
||||
assert attn_scores.shape == (1, 1, q_reshaped_len, kv_reshaped_len), \
|
||||
f"shape mismatch: {attn_scores.shape} != (1, 1, {q_reshaped_len}, {kv_reshaped_len})"
|
||||
|
||||
# 验证: 反对角线求和
|
||||
# 每个 stride x stride 块的反对角线: Q[奇]*K[偶] + Q[偶]*K[奇] = 2*1 + 1*2 = 4
|
||||
@@ -63,7 +98,6 @@ assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expect
|
||||
# Step 2: softmax_fuse_block_sum
|
||||
# ============================================================
|
||||
|
||||
reshaped_len = seq_len // stride
|
||||
scale = 1.4426950408889634 # log2(e) for exp2
|
||||
|
||||
block_sums = softmax_fuse_block_sum(
|
||||
@@ -71,15 +105,24 @@ block_sums = softmax_fuse_block_sum(
|
||||
block_size,
|
||||
segment_size,
|
||||
chunk_start=0,
|
||||
chunk_end=reshaped_len,
|
||||
real_q_len=reshaped_len,
|
||||
chunk_end=q_reshaped_len,
|
||||
real_q_len=q_reshaped_len,
|
||||
scale=scale,
|
||||
is_causal=False
|
||||
)
|
||||
|
||||
# 验证 shape: [batch, heads, q_blocks, k_blocks]
|
||||
q_blocks = q_reshaped_len // block_size # 128 / 128 = 1
|
||||
k_blocks = kv_reshaped_len // block_size # 512 / 128 = 4
|
||||
assert block_sums.shape == (1, 1, q_blocks, k_blocks), \
|
||||
f"shape mismatch: {block_sums.shape} != (1, 1, {q_blocks}, {k_blocks})"
|
||||
|
||||
# 验证: 每个 block 的 softmax 结果求和
|
||||
# 所有 attn_scores 相同 → softmax 均匀分布 → block_sum = block_size^2 / reshaped_len
|
||||
expected_sum = block_size * block_size / reshaped_len
|
||||
# 所有 attn_scores 相同 → softmax 均匀分布
|
||||
# 每行对一个 K block 的贡献 = block_size / kv_reshaped_len
|
||||
# 每个 Q block 有 block_size 行
|
||||
# block_sum = block_size * (block_size / kv_reshaped_len)
|
||||
expected_sum = block_size * block_size / kv_reshaped_len
|
||||
actual_sum = block_sums[0, 0, 0, 0].item()
|
||||
assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user