[WIP] Before work start.
This commit is contained in:
332
tests/test_xattn_bsa.py
Normal file
332
tests/test_xattn_bsa.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""
|
||||
Test XAttention + BSA with RULER benchmark data.
|
||||
|
||||
Tests XAttention sparse attention correctness using RULER NIAH task.
|
||||
|
||||
Attention methods:
|
||||
- Prefill: XAttention + BSA (sparse) or FlashAttention (dense)
|
||||
- Decode: FlashAttention (always, since q_len=1)
|
||||
|
||||
Usage (in compass conda env with BSA available):
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct
|
||||
|
||||
# Test with XAttention + BSA for prefill (default)
|
||||
python tests/test_xattn_bsa.py --prefill-method xattn
|
||||
|
||||
# Test with FlashAttention for prefill (baseline)
|
||||
python tests/test_xattn_bsa.py --prefill-method flash
|
||||
|
||||
# Test specific sample(s)
|
||||
python tests/test_xattn_bsa.py --sample-id 0
|
||||
python tests/test_xattn_bsa.py --sample-ids 0,1,2
|
||||
|
||||
Note: Compatible with transformers 4.53+ (handles both old `past_key_value`
|
||||
and new `past_key_values` API).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.cache_utils import DynamicCache
|
||||
|
||||
from nanovllm.ops.xattn import xattn_estimate
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
|
||||
# ============================================================
|
||||
# XAttention + BSA Functions
|
||||
# ============================================================
|
||||
|
||||
def expand_kv_for_gqa(key_states, value_states, num_heads):
|
||||
"""Expand KV for Grouped Query Attention."""
|
||||
num_kv_heads = key_states.shape[1]
|
||||
if num_heads == num_kv_heads:
|
||||
return key_states, value_states
|
||||
num_groups = num_heads // num_kv_heads
|
||||
return key_states.repeat_interleave(num_groups, dim=1), value_states.repeat_interleave(num_groups, dim=1)
|
||||
|
||||
|
||||
def flash_attention_forward(query_states, key_states, value_states, is_causal=True):
|
||||
"""Standard FlashAttention."""
|
||||
from flash_attn import flash_attn_func
|
||||
q = query_states.transpose(1, 2)
|
||||
k = key_states.transpose(1, 2)
|
||||
v = value_states.transpose(1, 2)
|
||||
return flash_attn_func(q, k, v, causal=is_causal).transpose(1, 2)
|
||||
|
||||
|
||||
def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9):
|
||||
"""XAttention + BSA sparse attention."""
|
||||
from block_sparse_attn import block_sparse_attn_func
|
||||
|
||||
batch_size, num_heads, q_len, head_dim = query_states.shape
|
||||
k_len = key_states.shape[2]
|
||||
|
||||
_, mask = xattn_estimate(
|
||||
query_states, key_states,
|
||||
chunk_size=16384, block_size=128, threshold=threshold,
|
||||
use_triton=True, causal=True,
|
||||
)
|
||||
|
||||
q_block_num = (q_len + 127) // 128
|
||||
k_block_num = (k_len + 127) // 128
|
||||
|
||||
q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
|
||||
k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
|
||||
v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
|
||||
|
||||
output = block_sparse_attn_func(
|
||||
q, k, v,
|
||||
torch.tensor([0, q_len], dtype=torch.int32, device=q.device),
|
||||
torch.tensor([0, k_len], dtype=torch.int32, device=k.device),
|
||||
torch.ones(num_heads, dtype=torch.int32, device=q.device),
|
||||
None,
|
||||
mask[:, :, :q_block_num, :k_block_num].contiguous(),
|
||||
q_len, k_len,
|
||||
p_dropout=0.0, deterministic=True, is_causal=True,
|
||||
)
|
||||
return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
|
||||
|
||||
|
||||
DEBUG = False # Set to True to enable debugging
|
||||
|
||||
def create_patched_forward(prefill_method="xattn", threshold=0.9):
|
||||
"""Create patched forward with configurable prefill method.
|
||||
|
||||
Args:
|
||||
prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense)
|
||||
threshold: XAttention threshold for block selection (only used when prefill_method="xattn")
|
||||
|
||||
Note:
|
||||
- Prefill (q_len > 1): Uses specified prefill_method
|
||||
- Decode (q_len = 1): Always uses FlashAttention (no sparse needed for single query)
|
||||
"""
|
||||
call_count = [0] # Mutable to track calls across layers
|
||||
|
||||
def patched_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
position_embeddings=None,
|
||||
attention_mask=None,
|
||||
past_key_value=None, # Old API (transformers < 4.57)
|
||||
past_key_values=None, # New API (transformers >= 4.57)
|
||||
cache_position=None,
|
||||
**kwargs
|
||||
):
|
||||
# Handle both old and new transformers API
|
||||
kv_cache = past_key_values if past_key_values is not None else past_key_value
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
num_heads = self.config.num_attention_heads
|
||||
num_kv_heads = self.config.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
|
||||
# Compute Q, K, V projections
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Apply rotary position embedding
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
# Handle KV cache
|
||||
if kv_cache is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = kv_cache.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# Expand KV for GQA
|
||||
key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads)
|
||||
|
||||
# Debug output
|
||||
if DEBUG and self.layer_idx == 0:
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= 5:
|
||||
phase = "prefill" if q_len > 1 else "decode"
|
||||
print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}")
|
||||
print(f" kv_cache is None: {kv_cache is None}")
|
||||
|
||||
# Choose attention method:
|
||||
# - Prefill (q_len > 1): Use prefill_method (xattn or flash)
|
||||
# - Decode (q_len = 1): Always use FlashAttention
|
||||
is_prefill = q_len > 1
|
||||
|
||||
if is_prefill and prefill_method == "xattn":
|
||||
# Prefill with XAttention + BSA (sparse)
|
||||
attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold)
|
||||
else:
|
||||
# Prefill with FlashAttention (dense) OR Decode (always FlashAttention)
|
||||
# Note: For decode (q_len=1), causal=False since single query attends to all KV
|
||||
attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill)
|
||||
|
||||
attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1))
|
||||
return attn_output, None
|
||||
|
||||
return patched_forward
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Data & Evaluation
|
||||
# ============================================================
|
||||
|
||||
def load_samples(filepath, indices=None):
|
||||
"""Load samples from JSONL file."""
|
||||
samples = []
|
||||
with open(filepath) as f:
|
||||
for i, line in enumerate(f):
|
||||
if indices is None or i in indices:
|
||||
sample = json.loads(line)
|
||||
sample["_idx"] = i
|
||||
samples.append(sample)
|
||||
return samples
|
||||
|
||||
|
||||
def string_match_all(output_text, expected_list):
|
||||
"""RULER metric: fraction of expected values found in output."""
|
||||
output_lower = output_text.lower().replace('\n', ' ')
|
||||
if not expected_list:
|
||||
return 1.0
|
||||
return sum(1.0 if exp.strip().lower() in output_lower else 0.0 for exp in expected_list) / len(expected_list)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Test
|
||||
# ============================================================
|
||||
|
||||
def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50):
|
||||
"""Test attention methods using RULER data.
|
||||
|
||||
Args:
|
||||
prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention
|
||||
"""
|
||||
prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)"
|
||||
|
||||
print("=" * 60)
|
||||
print("RULER NIAH Attention Test")
|
||||
print("=" * 60)
|
||||
print(f"Data: {data_file}")
|
||||
print(f"Samples: {sample_ids}")
|
||||
print(f"Prefill method: {prefill_desc}")
|
||||
print(f"Decode method: FlashAttention (always)")
|
||||
if prefill_method == "xattn":
|
||||
print(f"XAttention threshold: {threshold}")
|
||||
|
||||
samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None)
|
||||
if not samples:
|
||||
print("No samples found!")
|
||||
return False
|
||||
print(f"Loaded {len(samples)} samples")
|
||||
|
||||
# Load model
|
||||
print(f"\nLoading model: {model_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, torch_dtype=torch.float16, device_map="cuda",
|
||||
attn_implementation="eager", # Will be patched
|
||||
)
|
||||
model.eval()
|
||||
|
||||
# Patch all layers
|
||||
print(f"Patching attention layers...")
|
||||
print(f" - Prefill: {prefill_desc}")
|
||||
print(f" - Decode: FlashAttention")
|
||||
for idx, layer in enumerate(model.model.layers):
|
||||
layer.self_attn.layer_idx = idx # Ensure layer_idx is set
|
||||
layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__(
|
||||
layer.self_attn, type(layer.self_attn)
|
||||
)
|
||||
|
||||
total_score = 0.0
|
||||
results = []
|
||||
|
||||
for sample in samples:
|
||||
idx = sample["_idx"]
|
||||
prompt = sample["input"]
|
||||
expected = sample["outputs"]
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||
num_tokens = inputs["input_ids"].shape[1]
|
||||
print(f"\n--- Sample {idx} ({num_tokens} tokens) ---")
|
||||
print(f"Expected: {expected}")
|
||||
|
||||
with torch.no_grad():
|
||||
output = model.generate(
|
||||
inputs["input_ids"],
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True)
|
||||
score = string_match_all(output_text, expected)
|
||||
total_score += score
|
||||
|
||||
status = "✓ PASS" if score >= 0.5 else "✗ FAIL"
|
||||
print(f"Output: '{output_text[:100]}...'")
|
||||
print(f"Result: {status} (score={score:.2f})")
|
||||
results.append({"idx": idx, "score": score, "passed": score >= 0.5})
|
||||
|
||||
avg_score = total_score / len(samples)
|
||||
passed = sum(1 for r in results if r["passed"])
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
return avg_score >= 0.5
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark"
|
||||
)
|
||||
parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct")
|
||||
parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl")
|
||||
parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index")
|
||||
parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)")
|
||||
parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn",
|
||||
help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)")
|
||||
parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)")
|
||||
parser.add_argument("--max-new-tokens", type=int, default=50)
|
||||
# Keep old option for backwards compatibility
|
||||
parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = args.model.replace("~", "/home/zijie")
|
||||
|
||||
# Handle deprecated --no-xattn option
|
||||
prefill_method = args.prefill_method
|
||||
if args.no_xattn:
|
||||
prefill_method = "flash"
|
||||
print("Warning: --no-xattn is deprecated, use --prefill-method flash instead")
|
||||
|
||||
if args.sample_id is not None:
|
||||
sample_ids = [args.sample_id]
|
||||
elif args.sample_ids:
|
||||
sample_ids = [int(x) for x in args.sample_ids.split(",")]
|
||||
else:
|
||||
sample_ids = [0]
|
||||
|
||||
# Check BSA availability if using xattn
|
||||
if prefill_method == "xattn":
|
||||
try:
|
||||
from block_sparse_attn import block_sparse_attn_func
|
||||
print("✓ BSA (Block Sparse Attention) available")
|
||||
except ImportError:
|
||||
print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash")
|
||||
sys.exit(1)
|
||||
|
||||
if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens):
|
||||
print("\ntest_xattn_bsa: PASSED")
|
||||
else:
|
||||
print("\ntest_xattn_bsa: FAILED")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user