diff --git a/nanovllm/config.py b/nanovllm/config.py index 23c5200..2766654 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -48,7 +48,7 @@ class Config: # XAttention BSA specific parameters sparse_block_size: int = 128 # Block size for BSA (tokens per block) sparse_samples_per_chunk: int = 128 # Samples per chunk for estimation - sparse_threshold: float = 0.9 # Cumulative attention threshold (0-1) + sparse_threshold: float = 0.95 # Cumulative attention threshold (tau in XAttention) sparse_use_triton: bool = True # Use Triton kernels for estimation sparse_stride: int = 8 # Stride for Q/K downsampling diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index cb4c096..ad1fa2e 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -124,42 +124,6 @@ class XAttentionBSAPolicy(SparsePolicy): """ return available_blocks - def _load_all_historical_kv( - self, - cpu_block_table: List[int], - layer_id: int, - offload_engine: "OffloadEngine", - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Load all historical K/V from CPU to GPU. - - Args: - cpu_block_table: List of CPU block IDs - layer_id: Current layer index - offload_engine: OffloadEngine instance - - Returns: - (k_hist, v_hist) with shape [total_tokens, kv_heads, head_dim] - """ - if not cpu_block_table: - return None, None - - k_list = [] - v_list = [] - - for cpu_block_id in cpu_block_table: - k_block, v_block = offload_engine.load_block_full_from_cpu( - cpu_block_id, layer_id - ) - k_list.append(k_block) - v_list.append(v_block) - - # Concatenate: [num_blocks, block_size, kv_heads, head_dim] -> [total_tokens, kv_heads, head_dim] - k_hist = torch.cat(k_list, dim=0) - v_hist = torch.cat(v_list, dim=0) - - return k_hist, v_hist - def compute_chunked_prefill( self, q: torch.Tensor, diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 700be3a..5a22416 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -258,8 +258,12 @@ class Attention(nn.Module): raise RuntimeError("sparse_policy is required for chunked decode") # Check if policy supports decode phase + # If not, fallback to FullAttentionPolicy (e.g., XAttentionBSAPolicy only supports prefill) if not sparse_policy.supports_decode: - raise RuntimeError(f"{sparse_policy} does not support decode phase") + from nanovllm.kvcache.sparse import FullAttentionPolicy + sparse_policy = FullAttentionPolicy() + logger.debug(f"[DEBUG] {kvcache_manager.sparse_policy} doesn't support decode, " + f"falling back to FullAttentionPolicy") # [DEBUG] Verify execution path logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, " diff --git a/task_plan.md b/task_plan.md index 5255c1a..4441cae 100644 --- a/task_plan.md +++ b/task_plan.md @@ -1,90 +1,286 @@ -# Task Plan: XAttention BSA 集成到 nanovllm +# Task Plan: XAttention BSA 真正的 Sparse 实现 ## Goal -使用 `--sparse-policy XATTN_BSA` 运行 `test_ruler.py`,通过 `niah_single_1` 的前 5 个 sample。 +实现 XAttentionBSAPolicy 的真正 sparse attention,在 `select_blocks` 中使用 `xattn_estimate_chunked` 选择重要的 blocks,然后复用 FullAttentionPolicy 的 ring buffer pipeline。 **验收标准**: ```bash -CUDA_VISIBLE_DEVICES=X PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \ +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \ python tests/test_ruler.py \ --model ~/models/Llama-3.1-8B-Instruct \ --enable-offload \ --sparse-policy XATTN_BSA \ - --task niah_single_1 \ - --sample-ids 0,1,2,3,4 -# 期望: 5/5 PASS + --datasets niah_single_1 \ + --sample-indices 0,1,2,3,4 +# 期望: 5/5 PASS,并且真正使用 sparse selection ``` -## 当前状态 +## 当前状态: Phase 1 - 代码分析完成 -- `XAttentionBSAPolicy.compute_chunked_prefill` 实现 = `FullAttentionPolicy`(无 sparse) -- `xattn_estimate_chunked` 已实现并验证 -- BSA kernel (`block_sparse_attn`) 可用 +## 核心设计理解 + +### 1. Block Size 关系 + +| 参数 | 值 | 说明 | +|------|-----|------| +| BSA block_size | 128 tokens | XAttention 的 block 粒度 | +| kvcache_block_size | 1024 tokens | CPU offload 的 block 粒度 | +| 比例 | 1:8 | 1 CPU block = 8 BSA blocks | + +### 2. 特化条件(用户要求) + +- BSA chunk_size = 外部 chunk_size +- 这样 `xattn_estimate_chunked` 返回的 mask 可以直接映射到 CPU block selection +- 复用现有的 `flash_attn_with_lse` + `merge_attention_outputs` + +### 3. select_blocks 设计 + +``` +select_blocks(available_blocks, offload_engine, ctx) -> List[int] + │ + ├─ 1. 从 metadata cache 获取下采样的 K + │ (在 on_prefill_offload 中收集) + │ + ├─ 2. 调用 xattn_estimate_chunked(Q, K_downsampled, q_start_pos) + │ 返回 mask: [B, H, q_blocks, k_blocks] + │ + ├─ 3. 将 BSA k_blocks 映射到 CPU block IDs + │ 每 8 个 BSA blocks = 1 CPU block + │ 只要 8 个中有任意一个被选中,就保留该 CPU block + │ + └─ 4. 返回 selected_cpu_blocks +``` + +### 4. Metadata 存储策略 + +**方案 A**: 存储下采样的 K(内存友好) +```python +# on_prefill_offload 中: +k_downsampled = k_cache[::stride] # [block_size/stride, H, D] +self._k_cache[layer_id][cpu_block_id] = k_downsampled +``` + +**内存计算** (stride=8): +- 每 block: (1024/8) * 8 * 128 * 2 bytes = 256 KB +- 256 blocks * 32 layers = 2 GB (GPU 上用于快速估计) + +**方案 B**: 存储 min/max metadata (更省内存) +```python +# on_prefill_offload 中: +k_min = k_cache[:num_valid].min(dim=0).values # [H, D] +k_max = k_cache[:num_valid].max(dim=0).values # [H, D] +``` +- 但这需要不同的估计算法,不能直接用 xattn_estimate + +**决定**: 使用方案 A(下采样 K),因为可以直接复用 xattn_estimate_chunked ## Phases -- [ ] Phase 1: 理解当前代码路径 -- [ ] Phase 2: 实现 sparse mask 估计 -- [ ] Phase 3: 实现 BSA sparse 计算 -- [ ] Phase 4: 测试验证 +- [x] Phase 1: 代码分析,理解当前实现 +- [ ] Phase 2: 实现 on_prefill_offload 收集 K metadata +- [ ] Phase 3: 实现 select_blocks 中的 xattn estimation +- [ ] Phase 4: 实现 BSA block → CPU block 的映射 +- [ ] Phase 5: 测试验证 -## Phase 1: 理解当前代码路径 +## Phase 2: on_prefill_offload 实现 -### 1.1 确认 XATTN_BSA policy 是否被正确加载 -- [ ] 检查 `test_ruler.py` 如何解析 `--sparse-policy XATTN_BSA` -- [ ] 检查 `KVCacheManager` 如何实例化 sparse_policy -- [ ] 运行 baseline 测试(`--sparse-policy FULL`)确认基础功能正常 +### 需要修改的文件 +- `nanovllm/kvcache/sparse/xattn_bsa.py` -### 1.2 确认数据流 -- [ ] `compute_chunked_prefill` 的输入参数含义 -- [ ] `offload_engine` 提供的数据访问接口 -- [ ] 当前 chunk 的 K/V 如何获取 +### 实现细节 -## Phase 2: 实现 sparse mask 估计 +```python +class XAttentionBSAPolicy(SparsePolicy): + def __init__(self, threshold=0.9, stride=8, ...): + self.threshold = threshold + self.stride = stride + self._k_cache: Dict[int, Dict[int, torch.Tensor]] = {} + # _k_cache[layer_id][cpu_block_id] = k_downsampled -### 2.1 调用 xattn_estimate_chunked -- [ ] 在 `compute_chunked_prefill` 中加载历史 K -- [ ] 拼接历史 K + 当前 K -- [ ] 调用 `xattn_estimate_chunked(q, k_full, q_start_pos=...)` -- [ ] 获取 block mask + def initialize(self, num_layers, num_kv_heads, head_dim, num_cpu_blocks, dtype, device): + """初始化 K cache 结构""" + self._k_cache = {layer_id: {} for layer_id in range(num_layers)} + self._num_kv_heads = num_kv_heads + self._head_dim = head_dim -### 2.2 处理参数对齐 -- [ ] BSA block_size = 128 -- [ ] chunk_size 与 kvcache_block_size 的关系 -- [ ] q_start_pos 计算 + def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens): + """收集下采样的 K 用于后续估计""" + # k_cache: [block_size, num_kv_heads, head_dim] + k_downsampled = k_cache[:num_valid_tokens:self.stride].clone() + # k_downsampled: [num_valid_tokens//stride, num_kv_heads, head_dim] + self._k_cache[layer_id][cpu_block_id] = k_downsampled +``` -## Phase 3: 实现 BSA sparse 计算 +## Phase 3: select_blocks 实现 -### 3.1 方案选择 -- 选项 A: 历史 + 当前分开计算,然后 merge -- 选项 B: 全部一起用 BSA 计算 +### 关键问题 -### 3.2 实现 -- [ ] 构造 BSA 需要的输入格式 -- [ ] 调用 `block_sparse_attn_func` -- [ ] 处理输出格式 +1. **Q 从哪里来?** + - `ctx.query` 需要在调用 select_blocks 时传入 + - 当前 FullAttentionPolicy 传递 `query=None` + - 需要修改 compute_chunked_prefill 传递真实的 Q -## Phase 4: 测试验证 +2. **Q 的格式转换** + - 输入 Q: [seq_len, num_heads, head_dim] + - xattn 需要: [B, H, q_len, D] + - 转换: `q.unsqueeze(0).transpose(1, 2)` -### 4.1 单元测试 -- [ ] 验证 sparse mask 与 `test_xattn_estimate_chunked.py` 一致 +3. **K 的组装** + - 从 `_k_cache[layer_id]` 获取各 block 的下采样 K + - 按 `available_blocks` 顺序 cat 起来 + - 结果: [B, H, total_k_downsampled, D] -### 4.2 集成测试 -- [ ] 运行验收命令 -- [ ] 5/5 PASS +### 实现草案 -## Key Questions +```python +def select_blocks(self, available_blocks, offload_engine, ctx): + if not available_blocks or ctx.query is None: + return available_blocks -1. 历史 K 如何高效加载?(全量 vs 按需) -2. BSA causal mask 如何处理?(历史 non-causal + 当前 causal) + layer_id = ctx.layer_id + + # 1. 组装下采样的 K + k_list = [] + for cpu_block_id in available_blocks: + if cpu_block_id in self._k_cache[layer_id]: + k_list.append(self._k_cache[layer_id][cpu_block_id]) + + if not k_list: + return available_blocks + + k_hist = torch.cat(k_list, dim=0) # [total_tokens/stride, H, D] + k_hist = k_hist.unsqueeze(0).transpose(1, 2) # [1, H, k_len, D] + + # 2. 准备 Q + q = ctx.query # [seq_len, num_heads, head_dim] + q = q.unsqueeze(0).transpose(1, 2) # [1, H, q_len, D] + + # GQA 扩展(如果需要) + if q.shape[1] != k_hist.shape[1]: + num_groups = q.shape[1] // k_hist.shape[1] + k_hist = k_hist.repeat_interleave(num_groups, dim=1) + + # 3. 计算 q_start_pos + q_start_pos = len(available_blocks) * ctx.block_size + + # 4. 调用 xattn_estimate_chunked + # 注意:K 已经是下采样的,需要调整参数 + attn_sum, mask = xattn_estimate_chunked( + q, k_hist, + q_start_pos=q_start_pos // self.stride, # 调整到下采样空间 + block_size=self.BSA_BLOCK_SIZE // self.stride, # 16 + stride=1, # K 已经下采样 + threshold=self.threshold, + chunk_size=q.shape[2], # 与 Q 长度一致 + use_triton=self.use_triton, + ) + + # 5. 从 mask 提取 CPU block IDs + # mask: [1, H, q_blocks, k_blocks] + # 对所有 heads 取 OR + selected_mask = mask.any(dim=1).squeeze(0) # [q_blocks, k_blocks] + # 对所有 q_blocks 取 OR(只要任意 Q 位置需要这个 K block) + selected_k_mask = selected_mask.any(dim=0) # [k_blocks] + + # 6. 映射 BSA blocks → CPU blocks + # 每个 CPU block = 8 BSA blocks (block_size=1024, BSA_block=128) + bsa_to_cpu_ratio = ctx.block_size // self.BSA_BLOCK_SIZE # 8 + num_cpu_blocks = len(available_blocks) + + selected_cpu_indices = set() + for bsa_idx in selected_k_mask.nonzero(as_tuple=True)[0].tolist(): + cpu_idx = bsa_idx // bsa_to_cpu_ratio + if cpu_idx < num_cpu_blocks: + selected_cpu_indices.add(cpu_idx) + + selected_blocks = [available_blocks[i] for i in sorted(selected_cpu_indices)] + + logger.info(f"[XAttn] select_blocks: {len(available_blocks)} -> {len(selected_blocks)} " + f"({100*len(selected_blocks)/len(available_blocks):.1f}%)") + + return selected_blocks +``` + +## Phase 4: compute_chunked_prefill + +### 关键修改 + +1. **传递真实的 Q 给 select_blocks** + - 修改 PolicyContext 构造,设置 `query=q` + +2. **复用 FullAttentionPolicy 的 pipeline** + - 继承 FullAttentionPolicy 而不是 SparsePolicy + - 或者直接调用父类方法 + +### 方案对比 + +**方案 A**: XAttentionBSAPolicy 继承 FullAttentionPolicy +```python +class XAttentionBSAPolicy(FullAttentionPolicy): + # 只需要 override select_blocks 和 on_prefill_offload + # compute_chunked_prefill 直接用父类的 +``` + +**方案 B**: 独立实现,调用相同的 pipeline 代码 +```python +class XAttentionBSAPolicy(SparsePolicy): + def compute_chunked_prefill(self, q, k, v, ...): + # 复制 FullAttentionPolicy 的代码 + # 但修改 PolicyContext 传递 query=q +``` + +**决定**: 使用方案 B,因为需要在 compute_chunked_prefill 中修改 PolicyContext + +## Phase 5: 测试 + +### 单元测试 +```bash +# 测试 select_blocks 的 sparsity +python -c " +from nanovllm.kvcache.sparse.xattn_bsa import XAttentionBSAPolicy +policy = XAttentionBSAPolicy(threshold=0.9) +# ... 测试代码 +" +``` + +### 集成测试 +```bash +CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \ + python tests/test_ruler.py \ + --model ~/models/Llama-3.1-8B-Instruct \ + --enable-offload \ + --sparse-policy XATTN_BSA \ + --datasets niah_single_1 \ + --sample-indices 0,1,2,3,4 +``` + +## Key Decisions + +| 决策 | 理由 | +|------|------| +| 使用下采样 K 作为 metadata | 可以直接复用 xattn_estimate_chunked | +| stride=8 | 平衡内存和精度 | +| BSA blocks → CPU blocks 映射用 OR | 只要有一个 BSA block 被选中就保留 | +| 继承 FullAttentionPolicy 的 pipeline | 复用已验证的 ring buffer 流程 | + +## Files to Modify + +| 文件 | 修改 | +|------|------| +| `nanovllm/kvcache/sparse/xattn_bsa.py` | 主要实现:initialize, on_prefill_offload, select_blocks | + +## 注意事项 + +1. **GQA 处理**: Llama-3.1-8B 有 32 query heads, 8 kv heads,需要在估计时扩展 K +2. **内存管理**: `_k_cache` 存储在 GPU,需要在 reset() 时清理 +3. **Triton 兼容性**: xattn_estimate_chunked 有 Triton bug,可能需要用 PyTorch fallback +4. **边界条件**: 第一个 chunk (available_blocks=[]) 时直接返回空列表 + +## Errors Encountered + +(待填充) ## Status -**Currently in Phase 1** - 等待用户确认后开始 - -## 待讨论 - -请确认: -1. 这个 goal 和验收标准是否正确? -2. 我使用哪个 GPU 运行测试? +**Currently in Phase 1** - 代码分析完成,准备开始 Phase 2 实现 diff --git a/tests/test_xattn_bsa.py b/tests/test_xattn_bsa.py new file mode 100644 index 0000000..cd6529a --- /dev/null +++ b/tests/test_xattn_bsa.py @@ -0,0 +1,334 @@ +""" +Test XAttention + BSA with RULER benchmark data. + +Tests XAttention sparse attention correctness using RULER NIAH task. + +Attention methods: + - Prefill: XAttention + BSA (sparse) or FlashAttention (dense) + - Decode: FlashAttention (always, since q_len=1) + +Usage (in compass conda env with BSA available): + CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \ + python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct + + # Test with XAttention + BSA for prefill (default) + python tests/test_xattn_bsa.py --prefill-method xattn + + # Test with FlashAttention for prefill (baseline) + python tests/test_xattn_bsa.py --prefill-method flash + + # Test specific sample(s) + python tests/test_xattn_bsa.py --sample-id 0 + python tests/test_xattn_bsa.py --sample-ids 0,1,2 + +Note: Compatible with transformers 4.53+ (handles both old `past_key_value` + and new `past_key_values` API). +""" + +import argparse +import json +import sys +import torch +from pathlib import Path +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.cache_utils import DynamicCache + +from nanovllm.ops.xattn import xattn_estimate +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb + + +# ============================================================ +# XAttention + BSA Functions +# ============================================================ + +def expand_kv_for_gqa(key_states, value_states, num_heads): + """Expand KV for Grouped Query Attention.""" + num_kv_heads = key_states.shape[1] + if num_heads == num_kv_heads: + return key_states, value_states + num_groups = num_heads // num_kv_heads + return key_states.repeat_interleave(num_groups, dim=1), value_states.repeat_interleave(num_groups, dim=1) + + +def flash_attention_forward(query_states, key_states, value_states, is_causal=True): + """Standard FlashAttention.""" + from flash_attn import flash_attn_func + q = query_states.transpose(1, 2) + k = key_states.transpose(1, 2) + v = value_states.transpose(1, 2) + return flash_attn_func(q, k, v, causal=is_causal).transpose(1, 2) + + +def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9): + """XAttention + BSA sparse attention.""" + from block_sparse_attn import block_sparse_attn_func + + batch_size, num_heads, q_len, head_dim = query_states.shape + k_len = key_states.shape[2] + + _, mask = xattn_estimate( + query_states, key_states, + chunk_size=16384, block_size=128, threshold=threshold, + use_triton=True, causal=True, + ) + + q_block_num = (q_len + 127) // 128 + k_block_num = (k_len + 127) // 128 + + q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim) + k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim) + v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim) + + __import__('pdb').set_trace() + + output = block_sparse_attn_func( + q, k, v, + torch.tensor([0, q_len], dtype=torch.int32, device=q.device), + torch.tensor([0, k_len], dtype=torch.int32, device=k.device), + torch.ones(num_heads, dtype=torch.int32, device=q.device), + None, + mask[:, :, :q_block_num, :k_block_num].contiguous(), + q_len, k_len, + p_dropout=0.0, deterministic=True, is_causal=True, + ) + return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2) + + +DEBUG = False # Set to True to enable debugging + +def create_patched_forward(prefill_method="xattn", threshold=0.9): + """Create patched forward with configurable prefill method. + + Args: + prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense) + threshold: XAttention threshold for block selection (only used when prefill_method="xattn") + + Note: + - Prefill (q_len > 1): Uses specified prefill_method + - Decode (q_len = 1): Always uses FlashAttention (no sparse needed for single query) + """ + call_count = [0] # Mutable to track calls across layers + + def patched_forward( + self, + hidden_states, + position_embeddings=None, + attention_mask=None, + past_key_value=None, # Old API (transformers < 4.57) + past_key_values=None, # New API (transformers >= 4.57) + cache_position=None, + **kwargs + ): + # Handle both old and new transformers API + kv_cache = past_key_values if past_key_values is not None else past_key_value + + bsz, q_len, _ = hidden_states.size() + num_heads = self.config.num_attention_heads + num_kv_heads = self.config.num_key_value_heads + head_dim = self.head_dim + + # Compute Q, K, V projections + query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2) + + # Apply rotary position embedding + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + # Handle KV cache + if kv_cache is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = kv_cache.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # Expand KV for GQA + key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads) + + # Debug output + if DEBUG and self.layer_idx == 0: + call_count[0] += 1 + if call_count[0] <= 5: + phase = "prefill" if q_len > 1 else "decode" + print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}") + print(f" kv_cache is None: {kv_cache is None}") + + # Choose attention method: + # - Prefill (q_len > 1): Use prefill_method (xattn or flash) + # - Decode (q_len = 1): Always use FlashAttention + is_prefill = q_len > 1 + + if is_prefill and prefill_method == "xattn": + # Prefill with XAttention + BSA (sparse) + attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold) + else: + # Prefill with FlashAttention (dense) OR Decode (always FlashAttention) + # Note: For decode (q_len=1), causal=False since single query attends to all KV + attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill) + + attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1)) + return attn_output, None + + return patched_forward + + +# ============================================================ +# Data & Evaluation +# ============================================================ + +def load_samples(filepath, indices=None): + """Load samples from JSONL file.""" + samples = [] + with open(filepath) as f: + for i, line in enumerate(f): + if indices is None or i in indices: + sample = json.loads(line) + sample["_idx"] = i + samples.append(sample) + return samples + + +def string_match_all(output_text, expected_list): + """RULER metric: fraction of expected values found in output.""" + output_lower = output_text.lower().replace('\n', ' ') + if not expected_list: + return 1.0 + return sum(1.0 if exp.strip().lower() in output_lower else 0.0 for exp in expected_list) / len(expected_list) + + +# ============================================================ +# Test +# ============================================================ + +def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50): + """Test attention methods using RULER data. + + Args: + prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention + """ + prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)" + + print("=" * 60) + print("RULER NIAH Attention Test") + print("=" * 60) + print(f"Data: {data_file}") + print(f"Samples: {sample_ids}") + print(f"Prefill method: {prefill_desc}") + print(f"Decode method: FlashAttention (always)") + if prefill_method == "xattn": + print(f"XAttention threshold: {threshold}") + + samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None) + if not samples: + print("No samples found!") + return False + print(f"Loaded {len(samples)} samples") + + # Load model + print(f"\nLoading model: {model_path}") + tokenizer = AutoTokenizer.from_pretrained(model_path) + model = AutoModelForCausalLM.from_pretrained( + model_path, torch_dtype=torch.float16, device_map="cuda", + attn_implementation="eager", # Will be patched + ) + model.eval() + + # Patch all layers + print(f"Patching attention layers...") + print(f" - Prefill: {prefill_desc}") + print(f" - Decode: FlashAttention") + for idx, layer in enumerate(model.model.layers): + layer.self_attn.layer_idx = idx # Ensure layer_idx is set + layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__( + layer.self_attn, type(layer.self_attn) + ) + + total_score = 0.0 + results = [] + + for sample in samples: + idx = sample["_idx"] + prompt = sample["input"] + expected = sample["outputs"] + + inputs = tokenizer(prompt, return_tensors="pt").to("cuda") + num_tokens = inputs["input_ids"].shape[1] + print(f"\n--- Sample {idx} ({num_tokens} tokens) ---") + print(f"Expected: {expected}") + + with torch.no_grad(): + output = model.generate( + inputs["input_ids"], + max_new_tokens=max_new_tokens, + do_sample=False, + pad_token_id=tokenizer.eos_token_id, + ) + output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True) + score = string_match_all(output_text, expected) + total_score += score + + status = "✓ PASS" if score >= 0.5 else "✗ FAIL" + print(f"Output: '{output_text[:100]}...'") + print(f"Result: {status} (score={score:.2f})") + results.append({"idx": idx, "score": score, "passed": score >= 0.5}) + + avg_score = total_score / len(samples) + passed = sum(1 for r in results if r["passed"]) + + print(f"\n{'='*60}") + print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}") + print(f"{'='*60}") + + return avg_score >= 0.5 + + +def main(): + parser = argparse.ArgumentParser( + description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark" + ) + parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct") + parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl") + parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index") + parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)") + parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn", + help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)") + parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)") + parser.add_argument("--max-new-tokens", type=int, default=50) + # Keep old option for backwards compatibility + parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead") + args = parser.parse_args() + + model_path = args.model.replace("~", "/home/zijie") + + # Handle deprecated --no-xattn option + prefill_method = args.prefill_method + if args.no_xattn: + prefill_method = "flash" + print("Warning: --no-xattn is deprecated, use --prefill-method flash instead") + + if args.sample_id is not None: + sample_ids = [args.sample_id] + elif args.sample_ids: + sample_ids = [int(x) for x in args.sample_ids.split(",")] + else: + sample_ids = [0] + + # Check BSA availability if using xattn + if prefill_method == "xattn": + try: + from block_sparse_attn import block_sparse_attn_func + print("✓ BSA (Block Sparse Attention) available") + except ImportError: + print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash") + sys.exit(1) + + if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens): + print("\ntest_xattn_bsa: PASSED") + else: + print("\ntest_xattn_bsa: FAILED") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/tests/test_xattn_chunked.py b/tests/test_xattn_chunked.py new file mode 100644 index 0000000..d6fc4c6 --- /dev/null +++ b/tests/test_xattn_chunked.py @@ -0,0 +1,259 @@ +""" +Test: Compare xattn_estimate vs xattn_estimate_chunked +Verify that chunked estimation with EXTERNAL chunking produces the same mask as standard estimation. + +Uses real QKV data captured from model inference. +""" + +import sys +import os +import torch +import warnings + +from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked + +# ============================================================ +# Configuration +# ============================================================ + +BLOCK_SIZE = 64 +STRIDE = 4 +THRESHOLD = 0.9 +CHUNK_SIZE = 4096 + +# Default QKV data directory (relative to project root) +DEFAULT_QKV_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "kvcache") + +# ============================================================ +# Utility Functions +# ============================================================ + +def load_qkv(path): + """Load saved QKV data.""" + data = torch.load(path, map_location="cpu", weights_only=False) + print(f"Loaded: {path}") + print(f" Query shape: {data['query'].shape}") + print(f" Key shape: {data['key'].shape}") + print(f" Layer: {data['layer_id']}, Density: {data['density']:.2%}") + return data + + +def compare_masks(mask1, mask2, name1="standard", name2="chunked"): + """Compare two masks and report differences.""" + if mask1.shape != mask2.shape: + print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}") + return False + + diff = (mask1 != mask2).sum().item() + total = mask1.numel() + match_rate = (total - diff) / total * 100 + + print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})") + + if diff > 0: + diff_indices = torch.where(mask1 != mask2) + print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}") + + return diff == 0 + + +def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size): + """ + Run xattn_estimate_chunked with EXTERNAL chunking. + This simulates how chunked prefill should be used in practice. + """ + batch_size, num_heads, q_len, head_dim = query.shape + _, _, k_len, _ = key.shape + + q_block_num = (q_len + block_size - 1) // block_size + k_block_num = (k_len + block_size - 1) // block_size + + # If Q fits in one chunk, call directly + if q_len <= chunk_size: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return xattn_estimate_chunked( + query, key, + q_start_pos=q_start_pos, + block_size=block_size, + stride=stride, + threshold=threshold, + use_triton=True, + chunk_size=chunk_size, + ) + + # External chunking: split Q and call for each chunk + num_q_chunks = (q_len + chunk_size - 1) // chunk_size + print(f" External chunking: {num_q_chunks} chunks") + + combined_attn_sum = torch.zeros( + batch_size, num_heads, q_block_num, k_block_num, + dtype=query.dtype, device=query.device + ) + combined_mask = torch.zeros( + batch_size, num_heads, q_block_num, k_block_num, + dtype=torch.bool, device=query.device + ) + + q_block_offset = 0 + for q_chunk_idx in range(num_q_chunks): + q_chunk_start = q_chunk_idx * chunk_size + q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len) + + q_chunk = query[:, :, q_chunk_start:q_chunk_end, :] + + # For causal attention, K accumulates up to current Q position + k_end = q_start_pos + q_chunk_end + k_chunk = key[:, :, :k_end, :] + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + attn_sum_chunk, mask_chunk = xattn_estimate_chunked( + q_chunk, k_chunk, + q_start_pos=q_start_pos + q_chunk_start, + block_size=block_size, + stride=stride, + threshold=threshold, + use_triton=True, + chunk_size=chunk_size, + ) + + # Place chunk results into combined output + chunk_q_blocks = mask_chunk.shape[2] + chunk_k_blocks = mask_chunk.shape[3] + combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk + combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk + q_block_offset += chunk_q_blocks + + return combined_attn_sum, combined_mask + + +def test_single_qkv(qkv_path): + """Test a single QKV file.""" + data = load_qkv(qkv_path) + query = data["query"].cuda().to(torch.bfloat16) + key = data["key"].cuda().to(torch.bfloat16) + + seq_len = query.shape[2] + print(f"\nTesting with seq_len={seq_len}") + print("=" * 60) + + # Run standard xattn_estimate + print("[1] Running standard xattn_estimate...") + try: + attn_sum_std, mask_std = xattn_estimate( + query, key, + block_size=BLOCK_SIZE, + stride=STRIDE, + threshold=THRESHOLD, + chunk_size=CHUNK_SIZE, + use_triton=True, + ) + print(f" mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}") + except Exception as e: + print(f" ERROR: {e}") + import traceback + traceback.print_exc() + return False + + # Run chunked xattn_estimate with EXTERNAL chunking + print("[2] Running chunked xattn_estimate (external chunking)...") + try: + attn_sum_chunked, mask_chunked = run_chunked_externally( + query, key, + q_start_pos=0, + block_size=BLOCK_SIZE, + stride=STRIDE, + threshold=THRESHOLD, + chunk_size=CHUNK_SIZE, + ) + print(f" mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}") + except Exception as e: + print(f" ERROR: {e}") + import traceback + traceback.print_exc() + return False + + # Compare results + print("[3] Comparing results...") + chunked_q_blocks = mask_chunked.shape[2] + chunked_k_blocks = mask_chunked.shape[3] + + # Extract comparable region from standard mask + mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks] + + # Compare masks + masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked") + + # Compare attn_sums + attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks] + if attn_sum_std_comparable.shape == attn_sum_chunked.shape: + attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item() + print(f" Attn sum max diff: {attn_diff:.6f}") + else: + print(f" Attn sum shape mismatch") + + # Clean up GPU memory + del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked + torch.cuda.empty_cache() + + return masks_match + + +# ============================================================ +# Main Test +# ============================================================ + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked") + parser.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR, + help="Directory containing QKV files") + args = parser.parse_args() + + # QKV files to test + qkv_files = [ + os.path.join(args.qkv_dir, "qkv_3688.pt"), # ~4K + os.path.join(args.qkv_dir, "qkv_7888.pt"), # ~8K + os.path.join(args.qkv_dir, "qkv_15685.pt"), # ~16K + os.path.join(args.qkv_dir, "qkv_32485.pt"), # ~32K + os.path.join(args.qkv_dir, "qkv_64891.pt"), # ~64K + ] + + available_files = [p for p in qkv_files if os.path.exists(p)] + + if not available_files: + print(f"No QKV file found in {args.qkv_dir}.") + print(f"Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt") + sys.exit(1) + + print(f"Found {len(available_files)} QKV files to test") + print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})") + print(f"Using Triton kernels") + + all_passed = True + results = [] + + for qkv_path in available_files: + passed = test_single_qkv(qkv_path) + seq_len = int(os.path.basename(qkv_path).replace("qkv_", "").replace(".pt", "")) + results.append((seq_len, passed)) + if not passed: + all_passed = False + + # Summary + print("\n" + "=" * 60) + print("SUMMARY") + print("=" * 60) + for seq_len, passed in results: + status = "PASSED" if passed else "FAILED" + chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE + print(f" seq_len={seq_len} ({chunks} chunk{'s' if chunks > 1 else ''}): {status}") + + print("=" * 60) + if all_passed: + print("test_xattn_chunked: PASSED") + sys.exit(0) + else: + print("test_xattn_chunked: FAILED") + sys.exit(1) diff --git a/tests/test_xattn_kernels.py b/tests/test_xattn_kernels.py index 57fcd24..b4800c8 100644 --- a/tests/test_xattn_kernels.py +++ b/tests/test_xattn_kernels.py @@ -6,9 +6,10 @@ Test: XAttention Triton kernels 2. softmax_fuse_block_sum: 对 attention scores 做 softmax 后按 block 求和 数据流: - Q, K [batch, heads, seq_len, head_dim] + Q [batch, heads, q_len, head_dim] + K [batch, heads, kv_len, head_dim] ↓ flat_group_gemm_fuse_reshape - attn_scores [batch, heads, seq_len/stride, seq_len/stride] + attn_scores [batch, heads, q_len/stride, kv_len/stride] ↓ softmax_fuse_block_sum block_sums [batch, heads, q_blocks, k_blocks] """ @@ -21,7 +22,11 @@ from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_ # 参数配置 # ============================================================ -seq_len = 512 # Triton 要求 seq_len >= stride * BLOCK_M = 4 * 128 = 512 +# Triton 约束: q_len >= stride * BLOCK_M, kv_len >= stride * BLOCK_N +# A100: BLOCK_M = BLOCK_N = 128, 所以 min = 4 * 128 = 512 +# RTX 3090: BLOCK_M = BLOCK_N = 64, 所以 min = 4 * 64 = 256 +q_len = 512 +kv_len = 2048 head_dim = 128 stride = 4 block_size = 128 # softmax block size (in reshaped space) @@ -31,26 +36,56 @@ segment_size = 128 # Triton kernel 要求 segment_size >= block_size # 构造输入: 偶数位置=1, 奇数位置=2 # ============================================================ -Q = torch.zeros(1, 1, seq_len, head_dim, dtype=torch.bfloat16).cuda() -K = torch.zeros(1, 1, seq_len, head_dim, dtype=torch.bfloat16).cuda() -for i in range(seq_len): +Q = torch.zeros(1, 1, q_len, head_dim, dtype=torch.bfloat16).cuda() +K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda() + +for i in range(q_len): if i % 2 == 0: Q[0, 0, i, :] = 1 - K[0, 0, i, :] = 1 else: Q[0, 0, i, :] = 2 + +for i in range(kv_len): + if i % 2 == 0: + K[0, 0, i, :] = 1 + else: K[0, 0, i, :] = 2 # ============================================================ -# Step 1: flat_group_gemm_fuse_reshape +# Step 1: flat_group_gemm_fuse_reshape (chunked along K) # ============================================================ -attn_scores = flat_group_gemm_fuse_reshape( - Q, K, stride, - chunk_start=0, - chunk_end=seq_len // stride, - is_causal=False -) +q_reshaped_len = q_len // stride # 128 +kv_reshaped_len = kv_len // stride # 512 + +# 将 K 沿着长度维度分成多个 chunk +k_chunk_size = 512 # 每个 chunk 512 tokens +num_k_chunks = kv_len // k_chunk_size # 4 chunks + +attn_scores_list = [] +for k_chunk_idx in range(num_k_chunks): + k_start = k_chunk_idx * k_chunk_size + k_end = k_start + k_chunk_size + K_chunk = K[:, :, k_start:k_end, :] # [1, 1, k_chunk_size, head_dim] + + # 对每个 K chunk 调用 flat_group_gemm_fuse_reshape + # 输出: [batch, heads, q_len/stride, k_chunk_size/stride] + attn_chunk = flat_group_gemm_fuse_reshape( + Q, K_chunk, stride, + chunk_start=0, + chunk_end=q_reshaped_len, + is_causal=False + ) + attn_scores_list.append(attn_chunk) + +# 拼接所有 K chunks 的结果 +# 每个 chunk: [1, 1, q_reshaped_len, k_chunk_size/stride] +# 拼接后: [1, 1, q_reshaped_len, kv_reshaped_len] +attn_scores = torch.cat(attn_scores_list, dim=-1) + +# 验证 shape: [batch, heads, q_len/stride, kv_len/stride] +assert attn_scores.shape == (1, 1, q_reshaped_len, kv_reshaped_len), \ + f"shape mismatch: {attn_scores.shape} != (1, 1, {q_reshaped_len}, {kv_reshaped_len})" # 验证: 反对角线求和 # 每个 stride x stride 块的反对角线: Q[奇]*K[偶] + Q[偶]*K[奇] = 2*1 + 1*2 = 4 @@ -63,7 +98,6 @@ assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expect # Step 2: softmax_fuse_block_sum # ============================================================ -reshaped_len = seq_len // stride scale = 1.4426950408889634 # log2(e) for exp2 block_sums = softmax_fuse_block_sum( @@ -71,15 +105,24 @@ block_sums = softmax_fuse_block_sum( block_size, segment_size, chunk_start=0, - chunk_end=reshaped_len, - real_q_len=reshaped_len, + chunk_end=q_reshaped_len, + real_q_len=q_reshaped_len, scale=scale, is_causal=False ) +# 验证 shape: [batch, heads, q_blocks, k_blocks] +q_blocks = q_reshaped_len // block_size # 128 / 128 = 1 +k_blocks = kv_reshaped_len // block_size # 512 / 128 = 4 +assert block_sums.shape == (1, 1, q_blocks, k_blocks), \ + f"shape mismatch: {block_sums.shape} != (1, 1, {q_blocks}, {k_blocks})" + # 验证: 每个 block 的 softmax 结果求和 -# 所有 attn_scores 相同 → softmax 均匀分布 → block_sum = block_size^2 / reshaped_len -expected_sum = block_size * block_size / reshaped_len +# 所有 attn_scores 相同 → softmax 均匀分布 +# 每行对一个 K block 的贡献 = block_size / kv_reshaped_len +# 每个 Q block 有 block_size 行 +# block_sum = block_size * (block_size / kv_reshaped_len) +expected_sum = block_size * block_size / kv_reshaped_len actual_sum = block_sums[0, 0, 0, 0].item() assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}"