From c48753fc4e5ad8929b527d50ec577f7cdb74ae02 Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Tue, 20 Jan 2026 07:07:56 +0800
Subject: [PATCH] [WIP] Before work start.

---
 .claude/rules/plan-execution.md |  44 ++++
 tests/test_xattn_bsa.py         | 425 ++++++++++++++++++++++++++++++++
 2 files changed, 469 insertions(+)
 create mode 100644 .claude/rules/plan-execution.md
 create mode 100644 tests/test_xattn_bsa.py

diff --git a/.claude/rules/plan-execution.md b/.claude/rules/plan-execution.md
new file mode 100644
index 0000000..a4178f8
--- /dev/null
+++ b/.claude/rules/plan-execution.md
@@ -0,0 +1,44 @@
+# Plan 执行规则
+
+## 严格禁止未经许可执行 Plan
+
+**重要**:在用户明确发出以下指示之前,**绝对不能**执行 plan:
+- "执行 plan"
+- "开始实现"
+- "go ahead"
+- 或其他明确的开始信号
+
+### 允许的操作
+
+在等待用户确认期间,只能进行以下操作:
+
+| 操作类型 | 允许 | 示例 |
+|---------|------|------|
+| 阅读代码 | ✅ | 用 Read/Grep/Glob 探索代码库 |
+| 修改 plan 文件 | ✅ | 更新 task_plan.md, findings.md |
+| 讨论设计 | ✅ | 回答用户关于 plan 的问题 |
+| **修改源代码** | ❌ | 编辑 .py 文件 |
+| **创建新源文件** | ❌ | 新建 .py 文件 |
+| **运行测试/benchmark** | ❌ | 执行 python 脚本 |
+
+### 工作流程
+
+```
+用户: "请阅读 plan 并优化"
+Claude:
+  1. 阅读 task_plan.md
+  2. 阅读相关代码
+  3. 更新 task_plan.md 中的计划
+  4. 等待用户确认
+
+用户: "执行 plan"
+Claude:
+  1. 现在可以开始修改代码
+  2. 按照 plan 中的步骤执行
+```
+
+### 原因
+
+- 防止在设计未完善时过早修改代码
+- 给用户审查计划的机会
+- 避免不必要的重构
diff --git a/tests/test_xattn_bsa.py b/tests/test_xattn_bsa.py
new file mode 100644
index 0000000..c59b9d7
--- /dev/null
+++ b/tests/test_xattn_bsa.py
@@ -0,0 +1,425 @@
+"""
+Test XAttention + BSA with RULER benchmark data.
+
+Tests XAttention sparse attention correctness using RULER NIAH task.
+
+Attention methods:
+  - Prefill: XAttention + BSA (sparse) or FlashAttention (dense)
+  - Decode:  FlashAttention (always, since q_len=1)
+
+Usage (in compass conda env with BSA available):
+    CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
+        python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct
+
+    # Test with XAttention + BSA for prefill (default)
+    python tests/test_xattn_bsa.py --prefill-method xattn
+
+    # Test with FlashAttention for prefill (baseline)
+    python tests/test_xattn_bsa.py --prefill-method flash
+
+    # Test specific sample(s)
+    python tests/test_xattn_bsa.py --sample-id 0
+    python tests/test_xattn_bsa.py --sample-ids 0,1,2
+
+Note: Compatible with transformers 4.53+ (handles both old `past_key_value`
+      and new `past_key_values` API).
+"""
+
+import argparse
+import json
+import os
+import sys
+from pathlib import Path
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.cache_utils import DynamicCache
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
+from nanovllm.ops.xattn import xattn_estimate
+
+# Block edge length shared by the XAttention estimate and the BSA block mask.
+BSA_BLOCK_SIZE = 128
+
+DEBUG = False  # Set to True to enable debugging
+
+
+# ============================================================
+# XAttention + BSA Functions
+# ============================================================
+
+def expand_kv_for_gqa(key_states, value_states, num_heads):
+    """Repeat KV heads along dim 1 so they match the query head count (GQA).
+
+    Args:
+        key_states: Key tensor whose dim 1 is the KV-head axis.
+        value_states: Value tensor laid out like ``key_states``.
+        num_heads: Number of query heads to expand to.
+
+    Returns:
+        ``(key_states, value_states)``: the inputs themselves when the head
+        counts already match, otherwise copies with every KV head repeated
+        ``num_heads // num_kv_heads`` times along dim 1.
+    """
+    num_kv_heads = key_states.shape[1]
+    if num_heads == num_kv_heads:
+        return key_states, value_states
+    num_groups = num_heads // num_kv_heads
+    return (
+        key_states.repeat_interleave(num_groups, dim=1),
+        value_states.repeat_interleave(num_groups, dim=1),
+    )
+
+
+def flash_attention_forward(query_states, key_states, value_states, is_causal=True):
+    """Run dense FlashAttention.
+
+    Dims 1 and 2 of every input are swapped before the kernel call and the
+    result is swapped back, so the output uses the same layout as
+    ``query_states``.
+    """
+    from flash_attn import flash_attn_func
+
+    q = query_states.transpose(1, 2)
+    k = key_states.transpose(1, 2)
+    v = value_states.transpose(1, 2)
+    return flash_attn_func(q, k, v, causal=is_causal).transpose(1, 2)
+
+
+def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9):
+    """Run XAttention block selection followed by block-sparse attention.
+
+    Args:
+        query_states: Query tensor of shape (batch, num_heads, q_len, head_dim).
+        key_states: Key tensor with k_len on dim 2; it is reshaped with the
+            query head count, so KV heads must already be expanded for GQA.
+        value_states: Value tensor laid out like ``key_states``.
+        threshold: Block-selection threshold forwarded to ``xattn_estimate``.
+
+    Returns:
+        Attention output of shape (batch, num_heads, q_len, head_dim).
+
+    Raises:
+        ValueError: If the batch size is not 1. The tensors are flattened into
+            one packed sequence described by ``[0, len]`` offsets, which
+            cannot represent more than one sequence.
+    """
+    from block_sparse_attn import block_sparse_attn_func
+
+    batch_size, num_heads, q_len, head_dim = query_states.shape
+    k_len = key_states.shape[2]
+    if batch_size != 1:
+        raise ValueError(
+            f"xattn_bsa_forward only supports batch_size=1, got {batch_size}"
+        )
+
+    _, mask = xattn_estimate(
+        query_states, key_states,
+        chunk_size=16384, block_size=BSA_BLOCK_SIZE, threshold=threshold,
+        use_triton=True, causal=True,
+    )
+
+    # Ceil-divide: a trailing partial block still needs a mask entry.
+    q_block_num = (q_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
+    k_block_num = (k_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
+
+    # Drop the (size-1) batch dim: tensors become (tokens, heads, head_dim).
+    q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
+    k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
+    v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
+
+    cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
+    cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=k.device)
+    # NOTE(review): all-ones presumably marks every head as block-sparse;
+    # confirm against the block_sparse_attn documentation.
+    head_mask_type = torch.ones(num_heads, dtype=torch.int32, device=q.device)
+
+    output = block_sparse_attn_func(
+        q, k, v,
+        cu_seqlens_q,
+        cu_seqlens_k,
+        head_mask_type,
+        None,
+        # Crop the estimated mask to the block counts computed above.
+        mask[:, :, :q_block_num, :k_block_num].contiguous(),
+        q_len, k_len,
+        p_dropout=0.0, deterministic=True, is_causal=True,
+    )
+    return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
+
+
+def create_patched_forward(prefill_method="xattn", threshold=0.9):
+    """Create patched forward with configurable prefill method.
+
+    Args:
+        prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense)
+        threshold: XAttention threshold for block selection (only used when prefill_method="xattn")
+
+    Returns:
+        An unbound ``patched_forward(self, hidden_states, ...)`` function meant
+        to be bound to an attention module; it returns ``(attn_output, None)``.
+
+    Note:
+        - Prefill (q_len > 1): Uses specified prefill_method
+        - Decode (q_len = 1): Always uses FlashAttention (no sparse needed for single query)
+    """
+    call_count = [0]  # Mutable so the closure can count layer-0 calls
+
+    def patched_forward(
+        self,
+        hidden_states,
+        position_embeddings=None,
+        attention_mask=None,
+        past_key_value=None,   # Keyword name used by older transformers releases
+        past_key_values=None,  # Keyword name used by newer transformers releases
+        cache_position=None,
+        **kwargs
+    ):
+        if position_embeddings is None:
+            raise ValueError("patched_forward requires position_embeddings=(cos, sin)")
+
+        # Accept the cache under either keyword; the newer name wins.
+        kv_cache = past_key_values if past_key_values is not None else past_key_value
+
+        bsz, q_len, _ = hidden_states.size()
+        num_heads = self.config.num_attention_heads
+        num_kv_heads = self.config.num_key_value_heads
+        head_dim = self.head_dim
+
+        # Compute Q, K, V projections
+        query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
+
+        # Apply rotary position embedding
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        # Handle KV cache
+        if kv_cache is not None:
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = kv_cache.update(
+                key_states, value_states, self.layer_idx, cache_kwargs
+            )
+
+        # Expand KV for GQA
+        key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads)
+
+        # Debug output
+        if DEBUG and self.layer_idx == 0:
+            call_count[0] += 1
+            if call_count[0] <= 5:
+                phase = "prefill" if q_len > 1 else "decode"
+                print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}")
+                print(f"  kv_cache is None: {kv_cache is None}")
+
+        # Choose attention method:
+        #   - Prefill (q_len > 1): Use prefill_method (xattn or flash)
+        #   - Decode (q_len = 1): Always use FlashAttention
+        is_prefill = q_len > 1
+
+        if is_prefill and prefill_method == "xattn":
+            # Prefill with XAttention + BSA (sparse)
+            attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold)
+        else:
+            # Prefill with FlashAttention (dense) OR Decode (always FlashAttention)
+            # Note: For decode (q_len=1), causal=False since single query attends to all KV
+            attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill)
+
+        attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1))
+        return attn_output, None
+
+    return patched_forward
+
+
+# ============================================================
+# Data & Evaluation
+# ============================================================
+
+def load_samples(filepath, indices=None):
+    """Load samples from a JSONL file.
+
+    Args:
+        filepath: Path of the JSONL file (one JSON object per line).
+        indices: Optional container of zero-based line indices to keep;
+            ``None`` keeps every line.
+
+    Returns:
+        List of decoded samples, each with an added ``"_idx"`` key holding its
+        zero-based line index in the file.
+    """
+    samples = []
+    with open(filepath, encoding="utf-8") as f:
+        for i, line in enumerate(f):
+            if indices is not None and i not in indices:
+                continue
+            if not line.strip():
+                continue  # Tolerate blank lines (e.g. a trailing newline)
+            sample = json.loads(line)
+            sample["_idx"] = i
+            samples.append(sample)
+    return samples
+
+
+def string_match_all(output_text, expected_list):
+    """RULER metric: fraction of expected values found in output.
+
+    Matching is case-insensitive, ignores whitespace around each expected
+    value, and treats newlines in the output as spaces. An empty
+    ``expected_list`` scores 1.0.
+    """
+    if not expected_list:
+        return 1.0
+    output_lower = output_text.lower().replace('\n', ' ')
+    hits = sum(1 for exp in expected_list if exp.strip().lower() in output_lower)
+    return hits / len(expected_list)
+
+
+# ============================================================
+# Test
+# ============================================================
+
+def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50):
+    """Test attention methods using RULER data.
+
+    Args:
+        model_path: Model path or identifier passed to ``from_pretrained``.
+        data_file: JSONL file whose records provide "input" and "outputs".
+        sample_ids: Line indices to evaluate; a falsy value evaluates all lines.
+        prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention
+        threshold: XAttention threshold (only used when prefill_method="xattn").
+        max_new_tokens: Generation budget per sample.
+
+    Returns:
+        True when the average score over the loaded samples is at least 0.5;
+        False otherwise, including when no samples were loaded.
+    """
+    prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)"
+
+    print("=" * 60)
+    print("RULER NIAH Attention Test")
+    print("=" * 60)
+    print(f"Data: {data_file}")
+    print(f"Samples: {sample_ids}")
+    print(f"Prefill method: {prefill_desc}")
+    print("Decode method: FlashAttention (always)")
+    if prefill_method == "xattn":
+        print(f"XAttention threshold: {threshold}")
+
+    samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None)
+    if not samples:
+        print("No samples found!")
+        return False
+    print(f"Loaded {len(samples)} samples")
+
+    # Load model
+    print(f"\nLoading model: {model_path}")
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path, torch_dtype=torch.float16, device_map="cuda",
+        attn_implementation="eager",  # Will be patched
+    )
+    model.eval()
+
+    # Patch all layers
+    print("Patching attention layers...")
+    print(f"  - Prefill: {prefill_desc}")
+    print("  - Decode: FlashAttention")
+    for idx, layer in enumerate(model.model.layers):
+        layer.self_attn.layer_idx = idx  # Ensure layer_idx is set
+        layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__(
+            layer.self_attn, type(layer.self_attn)
+        )
+
+    total_score = 0.0
+    results = []
+
+    for sample in samples:
+        idx = sample["_idx"]
+        prompt = sample["input"]
+        expected = sample["outputs"]
+
+        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+        num_tokens = inputs["input_ids"].shape[1]
+        print(f"\n--- Sample {idx} ({num_tokens} tokens) ---")
+        print(f"Expected: {expected}")
+
+        with torch.no_grad():
+            output = model.generate(
+                inputs["input_ids"],
+                max_new_tokens=max_new_tokens,
+                do_sample=False,
+                pad_token_id=tokenizer.eos_token_id,
+            )
+        output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True)
+        score = string_match_all(output_text, expected)
+        total_score += score
+
+        passed = score >= 0.5
+        status = "✓ PASS" if passed else "✗ FAIL"
+        print(f"Output: '{output_text[:100]}...'")
+        print(f"Result: {status} (score={score:.2f})")
+        results.append({"idx": idx, "score": score, "passed": passed})
+
+    avg_score = total_score / len(samples)
+    num_passed = sum(1 for r in results if r["passed"])
+
+    print(f"\n{'='*60}")
+    print(f"Results: {num_passed}/{len(samples)} passed, avg_score={avg_score:.3f}")
+    print(f"{'='*60}")
+
+    return avg_score >= 0.5
+
+
+def main():
+    """Parse CLI arguments, run the RULER test, and exit non-zero on failure."""
+    parser = argparse.ArgumentParser(
+        description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark"
+    )
+    parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct")
+    parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl")
+    parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index")
+    parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)")
+    parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn",
+                        help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)")
+    parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)")
+    parser.add_argument("--max-new-tokens", type=int, default=50)
+    # Keep old option for backwards compatibility
+    parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead")
+    args = parser.parse_args()
+
+    # Expand "~" for whichever user runs the script instead of a fixed home dir.
+    model_path = os.path.expanduser(args.model)
+
+    # Handle deprecated --no-xattn option
+    prefill_method = args.prefill_method
+    if args.no_xattn:
+        prefill_method = "flash"
+        print("Warning: --no-xattn is deprecated, use --prefill-method flash instead")
+
+    if args.sample_id is not None:
+        sample_ids = [args.sample_id]
+    elif args.sample_ids:
+        # Skip empty entries so input such as "0,1," does not crash int().
+        sample_ids = [int(x) for x in args.sample_ids.split(",") if x.strip()]
+    else:
+        sample_ids = [0]
+
+    # Check BSA availability if using xattn
+    if prefill_method == "xattn":
+        try:
+            from block_sparse_attn import block_sparse_attn_func  # noqa: F401
+            print("✓ BSA (Block Sparse Attention) available")
+        except ImportError:
+            print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash")
+            sys.exit(1)
+
+    if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens):
+        print("\ntest_xattn_bsa: PASSED")
+    else:
+        print("\ntest_xattn_bsa: FAILED")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()