Files
nano-vllm/tests/test_xattn_bsa.py
2026-01-23 03:36:12 +08:00

335 lines
13 KiB
Python

"""
Test XAttention + BSA with RULER benchmark data.
Tests XAttention sparse attention correctness using RULER NIAH task.
Attention methods:
- Prefill: XAttention + BSA (sparse) or FlashAttention (dense)
- Decode: FlashAttention (always, since q_len=1)
Usage (in compass conda env with BSA available):
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct
# Test with XAttention + BSA for prefill (default)
python tests/test_xattn_bsa.py --prefill-method xattn
# Test with FlashAttention for prefill (baseline)
python tests/test_xattn_bsa.py --prefill-method flash
# Test specific sample(s)
python tests/test_xattn_bsa.py --sample-id 0
python tests/test_xattn_bsa.py --sample-ids 0,1,2
Note: Compatible with transformers 4.53+ (handles both old `past_key_value`
and new `past_key_values` API).
"""
import argparse
import json
import sys
import torch
from pathlib import Path
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.cache_utils import DynamicCache
from nanovllm.ops.xattn import xattn_estimate
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
# ============================================================
# XAttention + BSA Functions
# ============================================================
def expand_kv_for_gqa(key_states, value_states, num_heads):
    """Repeat KV heads so their count matches the query head count (GQA).

    Args:
        key_states / value_states: tensors shaped (batch, kv_heads, seq, head_dim).
        num_heads: number of query heads; must be a multiple of kv_heads.

    Returns:
        (key, value) with the head dimension expanded to ``num_heads``.
        The originals are returned unchanged when no expansion is needed.
    """
    kv_heads = key_states.shape[1]
    if kv_heads == num_heads:
        # Standard MHA — nothing to expand.
        return key_states, value_states
    repeats = num_heads // kv_heads
    expanded_k = key_states.repeat_interleave(repeats, dim=1)
    expanded_v = value_states.repeat_interleave(repeats, dim=1)
    return expanded_k, expanded_v
def flash_attention_forward(query_states, key_states, value_states, is_causal=True):
    """Dense attention via FlashAttention.

    Inputs arrive as (batch, heads, seq, head_dim); ``flash_attn_func``
    expects (batch, seq, heads, head_dim), so we transpose on the way in
    and transpose the result back on the way out.
    """
    from flash_attn import flash_attn_func
    attn_out = flash_attn_func(
        query_states.transpose(1, 2),
        key_states.transpose(1, 2),
        value_states.transpose(1, 2),
        causal=is_causal,
    )
    return attn_out.transpose(1, 2)
def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9):
    """XAttention + BSA sparse attention.

    Estimates a 128x128 block-level sparsity mask with XAttention, then runs
    Block Sparse Attention over the selected blocks.

    Args:
        query_states: (batch, heads, q_len, head_dim); batch is assumed to be 1
            (the varlen reshape below drops the batch dimension).
        key_states / value_states: (batch, heads, k_len, head_dim), already
            expanded to the full query head count (GQA expansion is the
            caller's responsibility).
        threshold: XAttention block-selection threshold.

    Returns:
        Attention output shaped (batch, heads, q_len, head_dim).
    """
    from block_sparse_attn import block_sparse_attn_func
    batch_size, num_heads, q_len, head_dim = query_states.shape
    k_len = key_states.shape[2]
    # Block-level importance mask from XAttention (causal, 128-wide blocks).
    _, mask = xattn_estimate(
        query_states, key_states,
        chunk_size=16384, block_size=128, threshold=threshold,
        use_triton=True, causal=True,
    )
    # Number of 128-wide blocks covering each sequence (ceil division).
    q_block_num = (q_len + 127) // 128
    k_block_num = (k_len + 127) // 128
    # BSA expects varlen layout: (total_tokens, heads, head_dim).
    q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
    k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
    v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
    # Fix: removed a leftover `__import__('pdb').set_trace()` breakpoint here
    # that halted every run waiting for interactive debugger input.
    output = block_sparse_attn_func(
        q, k, v,
        # Cumulative sequence lengths for a single-sequence "batch".
        torch.tensor([0, q_len], dtype=torch.int32, device=q.device),
        torch.tensor([0, k_len], dtype=torch.int32, device=k.device),
        # Every head takes the block-sparse path.
        torch.ones(num_heads, dtype=torch.int32, device=q.device),
        None,
        mask[:, :, :q_block_num, :k_block_num].contiguous(),
        q_len, k_len,
        p_dropout=0.0, deterministic=True, is_causal=True,
    )
    return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
# Module-level switch: when True, the patched forward prints q_len/k_len and
# cache state for layer 0 during the first few attention calls.
DEBUG = False  # Set to True to enable debugging
def create_patched_forward(prefill_method="xattn", threshold=0.9):
    """Create patched forward with configurable prefill method.

    Returns a function suitable for binding over a Llama attention module's
    ``forward`` (same signature, returns ``(attn_output, None)``).

    Args:
        prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for
            FlashAttention (dense)
        threshold: XAttention threshold for block selection (only used when
            prefill_method="xattn")

    Note:
        - Prefill (q_len > 1): Uses specified prefill_method
        - Decode (q_len = 1): Always uses FlashAttention (no sparse needed
          for single query)
    """
    call_count = [0]  # Mutable to track calls across layers

    def patched_forward(
        self,
        hidden_states,
        position_embeddings=None,
        attention_mask=None,
        past_key_value=None,   # Old API (transformers < 4.57)
        past_key_values=None,  # New API (transformers >= 4.57)
        cache_position=None,
        **kwargs
    ):
        # Handle both old and new transformers API: prefer the new kwarg,
        # fall back to the old one when the new is absent.
        kv_cache = past_key_values if past_key_values is not None else past_key_value
        bsz, q_len, _ = hidden_states.size()
        num_heads = self.config.num_attention_heads
        num_kv_heads = self.config.num_key_value_heads
        head_dim = self.head_dim
        # Compute Q, K, V projections -> (batch, heads, seq, head_dim)
        query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
        # Apply rotary position embedding
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        # Handle KV cache: append this step's K/V and get back the full history
        # for this layer (decode steps see k_len > q_len afterwards).
        if kv_cache is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = kv_cache.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
        # Expand KV for GQA so the KV head count matches the query head count.
        key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads)
        # Debug output (layer 0, first 5 calls only)
        if DEBUG and self.layer_idx == 0:
            call_count[0] += 1
            if call_count[0] <= 5:
                phase = "prefill" if q_len > 1 else "decode"
                print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}")
                print(f" kv_cache is None: {kv_cache is None}")
        # Choose attention method:
        # - Prefill (q_len > 1): Use prefill_method (xattn or flash)
        # - Decode (q_len = 1): Always use FlashAttention
        is_prefill = q_len > 1
        if is_prefill and prefill_method == "xattn":
            # Prefill with XAttention + BSA (sparse)
            attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold)
        else:
            # Prefill with FlashAttention (dense) OR Decode (always FlashAttention)
            # Note: For decode (q_len=1), causal=False since single query attends to all KV
            attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill)
        # Merge heads back to (batch, seq, hidden) and project out.
        attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1))
        # Second element is attention weights; not computed here.
        return attn_output, None

    return patched_forward
# ============================================================
# Data & Evaluation
# ============================================================
def load_samples(filepath, indices=None):
    """Load samples from JSONL file.

    Args:
        filepath: path to a JSONL file, one JSON object per line.
        indices: optional collection of 0-based line indices to keep;
            None keeps every line.

    Returns:
        List of parsed samples, each annotated with its line index
        under the "_idx" key.
    """
    selected = []
    with open(filepath) as fh:
        for line_no, raw in enumerate(fh):
            if indices is not None and line_no not in indices:
                continue
            record = json.loads(raw)
            record["_idx"] = line_no
            selected.append(record)
    return selected
def string_match_all(output_text, expected_list):
    """RULER metric: fraction of expected values found in output.

    Matching is case-insensitive; newlines in the output are treated as
    spaces. An empty expectation list scores a perfect 1.0.
    """
    haystack = output_text.lower().replace('\n', ' ')
    if not expected_list:
        return 1.0
    hits = sum(1.0 for needle in expected_list if needle.strip().lower() in haystack)
    return hits / len(expected_list)
# ============================================================
# Test
# ============================================================
def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50):
    """Test attention methods using RULER data.

    Loads the model, monkey-patches every decoder layer's attention forward
    with ``create_patched_forward``, runs greedy generation on the selected
    samples, and scores outputs with the RULER string-match metric.

    Args:
        model_path: HF model path or identifier.
        data_file: JSONL file of RULER samples ("input"/"outputs" fields).
        sample_ids: sample indices to run; falsy runs all samples.
        prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention
        threshold: XAttention block-selection threshold (xattn only).
        max_new_tokens: generation budget per sample.

    Returns:
        True when the average score across samples is >= 0.5 (also False
        when no samples could be loaded).
    """
    prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)"
    print("=" * 60)
    print("RULER NIAH Attention Test")
    print("=" * 60)
    print(f"Data: {data_file}")
    print(f"Samples: {sample_ids}")
    print(f"Prefill method: {prefill_desc}")
    print(f"Decode method: FlashAttention (always)")
    if prefill_method == "xattn":
        print(f"XAttention threshold: {threshold}")
    samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None)
    if not samples:
        print("No samples found!")
        return False
    print(f"Loaded {len(samples)} samples")
    # Load model
    print(f"\nLoading model: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.float16, device_map="cuda",
        attn_implementation="eager",  # Will be patched
    )
    model.eval()
    # Patch all layers: bind the factory's closure as a method on each
    # self_attn module so `self` resolves to that module.
    print(f"Patching attention layers...")
    print(f" - Prefill: {prefill_desc}")
    print(f" - Decode: FlashAttention")
    for idx, layer in enumerate(model.model.layers):
        layer.self_attn.layer_idx = idx  # Ensure layer_idx is set
        layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__(
            layer.self_attn, type(layer.self_attn)
        )
    total_score = 0.0
    results = []
    for sample in samples:
        idx = sample["_idx"]
        prompt = sample["input"]
        expected = sample["outputs"]
        inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
        num_tokens = inputs["input_ids"].shape[1]
        print(f"\n--- Sample {idx} ({num_tokens} tokens) ---")
        print(f"Expected: {expected}")
        # Greedy decoding for reproducibility.
        with torch.no_grad():
            output = model.generate(
                inputs["input_ids"],
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id,
            )
        # Strip the prompt tokens; decode only the generated continuation.
        output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True)
        score = string_match_all(output_text, expected)
        total_score += score
        status = "✓ PASS" if score >= 0.5 else "✗ FAIL"
        print(f"Output: '{output_text[:100]}...'")
        print(f"Result: {status} (score={score:.2f})")
        results.append({"idx": idx, "score": score, "passed": score >= 0.5})
    avg_score = total_score / len(samples)
    passed = sum(1 for r in results if r["passed"])
    print(f"\n{'='*60}")
    print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}")
    print(f"{'='*60}")
    return avg_score >= 0.5
def main():
    """Parse CLI arguments and run the RULER NIAH attention test.

    Exits with status 1 when BSA is requested but unavailable, or when the
    test fails.
    """
    parser = argparse.ArgumentParser(
        description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark"
    )
    parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct")
    parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl")
    parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index")
    parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)")
    parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn",
                        help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)")
    parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)")
    parser.add_argument("--max-new-tokens", type=int, default=50)
    # Keep old option for backwards compatibility
    parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead")
    args = parser.parse_args()
    # Fix: expand "~" portably instead of hard-coding one user's home
    # directory ("/home/zijie"), which broke the script for everyone else.
    model_path = str(Path(args.model).expanduser())
    # Handle deprecated --no-xattn option
    prefill_method = args.prefill_method
    if args.no_xattn:
        prefill_method = "flash"
        print("Warning: --no-xattn is deprecated, use --prefill-method flash instead")
    # Resolve which samples to run; default to sample 0.
    if args.sample_id is not None:
        sample_ids = [args.sample_id]
    elif args.sample_ids:
        sample_ids = [int(x) for x in args.sample_ids.split(",")]
    else:
        sample_ids = [0]
    # Check BSA availability if using xattn
    if prefill_method == "xattn":
        try:
            from block_sparse_attn import block_sparse_attn_func
            print("✓ BSA (Block Sparse Attention) available")
        except ImportError:
            print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash")
            sys.exit(1)
    if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens):
        print("\ntest_xattn_bsa: PASSED")
    else:
        print("\ntest_xattn_bsa: FAILED")
        sys.exit(1)


if __name__ == "__main__":
    main()