🗑️ chore: remove redundant XAttention test files
Remove 6 obsolete test files: - test_xattn_bsa.py - XAttn+BSA integration (covered by test_ruler) - test_xattn_chunked.py - duplicate of test_xattn_estimate_chunked - test_xattn_estimate_chunked.py - chunked prefill validation - test_xattn_kernels.py - Triton kernel unit tests - test_xattn_kv_chunking_batch.py - batch KV chunking validation - test_chunk_attention_graph.py - superseded by graph_reuse version Retained: test_xattn_estimate_alignment.py (critical kernel validation) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,151 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test: Pre-allocated chunk pair graphs for block sparse attention.
|
||||
|
||||
Each (Q_chunk, K_chunk) pair has its own captured CUDA graph.
|
||||
Zero copy_() during replay - all data pre-filled.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph.py
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkAttentionGraph:
|
||||
"""Container for a captured chunk attention graph."""
|
||||
graph: torch.cuda.CUDAGraph
|
||||
static_q: torch.Tensor
|
||||
static_k: torch.Tensor
|
||||
static_v: torch.Tensor
|
||||
static_output: torch.Tensor
|
||||
static_lse: torch.Tensor
|
||||
causal: bool
|
||||
|
||||
|
||||
def capture_chunk_attention_graph(
|
||||
chunk_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
scale: float,
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
causal: bool = False,
|
||||
) -> ChunkAttentionGraph:
|
||||
"""Capture a CUDA graph for single chunk attention."""
|
||||
static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device)
|
||||
static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
|
||||
static_q.normal_()
|
||||
static_k.normal_()
|
||||
static_v.normal_()
|
||||
|
||||
# Warmup
|
||||
with torch.inference_mode():
|
||||
for _ in range(3):
|
||||
_ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Capture
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.inference_mode():
|
||||
with torch.cuda.graph(graph):
|
||||
static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal)
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
return ChunkAttentionGraph(
|
||||
graph=graph,
|
||||
static_q=static_q,
|
||||
static_k=static_k,
|
||||
static_v=static_v,
|
||||
static_output=static_output,
|
||||
static_lse=static_lse,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16
|
||||
|
||||
chunk_size = 64
|
||||
num_chunks = 4
|
||||
num_heads = 8
|
||||
num_kv_heads = 8
|
||||
head_dim = 64
|
||||
scale = 1.0 / (head_dim ** 0.5)
|
||||
seq_len = chunk_size * num_chunks
|
||||
|
||||
print(f"Device: {torch.cuda.get_device_name()}")
|
||||
print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}")
|
||||
print(f"Total graphs: {num_chunks * (num_chunks + 1) // 2}")
|
||||
|
||||
# Test data
|
||||
full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device)
|
||||
full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device)
|
||||
|
||||
# Reference
|
||||
with torch.inference_mode():
|
||||
full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True)
|
||||
|
||||
# Capture all graphs
|
||||
graphs: List[List[Optional[ChunkAttentionGraph]]] = [[None] * num_chunks for _ in range(num_chunks)]
|
||||
for q_idx in range(num_chunks):
|
||||
for k_idx in range(q_idx + 1):
|
||||
graphs[q_idx][k_idx] = capture_chunk_attention_graph(
|
||||
chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype,
|
||||
causal=(k_idx == q_idx)
|
||||
)
|
||||
print("All graphs captured")
|
||||
|
||||
# Pre-fill static tensors
|
||||
for q_idx in range(num_chunks):
|
||||
for k_idx in range(q_idx + 1):
|
||||
g = graphs[q_idx][k_idx]
|
||||
g.static_q.copy_(full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size])
|
||||
g.static_k.copy_(full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size])
|
||||
g.static_v.copy_(full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size])
|
||||
print("Static tensors pre-filled")
|
||||
|
||||
# Replay and merge
|
||||
chunked_output = torch.zeros_like(full_output)
|
||||
for q_idx in range(num_chunks):
|
||||
acc_out, acc_lse = None, None
|
||||
for k_idx in range(q_idx + 1):
|
||||
g = graphs[q_idx][k_idx]
|
||||
g.graph.replay()
|
||||
out, lse = g.static_output.clone(), g.static_lse.clone()
|
||||
if acc_out is None:
|
||||
acc_out, acc_lse = out, lse
|
||||
else:
|
||||
with torch.inference_mode():
|
||||
acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse)
|
||||
chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out
|
||||
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Compare
|
||||
all_pass = True
|
||||
for q_idx in range(num_chunks):
|
||||
s, e = q_idx * chunk_size, (q_idx + 1) * chunk_size
|
||||
diff = (full_output[:, s:e] - chunked_output[:, s:e]).abs().max().item()
|
||||
status = "✅" if diff < 1e-2 else "❌"
|
||||
print(f"Q[{q_idx}]: max_diff={diff:.2e} {status}")
|
||||
if diff >= 1e-2:
|
||||
all_pass = False
|
||||
|
||||
print("✅ PASSED" if all_pass else "❌ FAILED")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,334 +0,0 @@
|
||||
"""
|
||||
Test XAttention + BSA with RULER benchmark data.
|
||||
|
||||
Tests XAttention sparse attention correctness using RULER NIAH task.
|
||||
|
||||
Attention methods:
|
||||
- Prefill: XAttention + BSA (sparse) or FlashAttention (dense)
|
||||
- Decode: FlashAttention (always, since q_len=1)
|
||||
|
||||
Usage (in compass conda env with BSA available):
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct
|
||||
|
||||
# Test with XAttention + BSA for prefill (default)
|
||||
python tests/test_xattn_bsa.py --prefill-method xattn
|
||||
|
||||
# Test with FlashAttention for prefill (baseline)
|
||||
python tests/test_xattn_bsa.py --prefill-method flash
|
||||
|
||||
# Test specific sample(s)
|
||||
python tests/test_xattn_bsa.py --sample-id 0
|
||||
python tests/test_xattn_bsa.py --sample-ids 0,1,2
|
||||
|
||||
Note: Compatible with transformers 4.53+ (handles both old `past_key_value`
|
||||
and new `past_key_values` API).
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.cache_utils import DynamicCache
|
||||
|
||||
from nanovllm.ops.xattn import xattn_estimate
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
|
||||
|
||||
|
||||
# ============================================================
|
||||
# XAttention + BSA Functions
|
||||
# ============================================================
|
||||
|
||||
def expand_kv_for_gqa(key_states, value_states, num_heads):
|
||||
"""Expand KV for Grouped Query Attention."""
|
||||
num_kv_heads = key_states.shape[1]
|
||||
if num_heads == num_kv_heads:
|
||||
return key_states, value_states
|
||||
num_groups = num_heads // num_kv_heads
|
||||
return key_states.repeat_interleave(num_groups, dim=1), value_states.repeat_interleave(num_groups, dim=1)
|
||||
|
||||
|
||||
def flash_attention_forward(query_states, key_states, value_states, is_causal=True):
|
||||
"""Standard FlashAttention."""
|
||||
from flash_attn import flash_attn_func
|
||||
q = query_states.transpose(1, 2)
|
||||
k = key_states.transpose(1, 2)
|
||||
v = value_states.transpose(1, 2)
|
||||
return flash_attn_func(q, k, v, causal=is_causal).transpose(1, 2)
|
||||
|
||||
|
||||
def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9):
|
||||
"""XAttention + BSA sparse attention."""
|
||||
from block_sparse_attn import block_sparse_attn_func
|
||||
|
||||
batch_size, num_heads, q_len, head_dim = query_states.shape
|
||||
k_len = key_states.shape[2]
|
||||
|
||||
_, mask = xattn_estimate(
|
||||
query_states, key_states,
|
||||
chunk_size=16384, block_size=128, threshold=threshold,
|
||||
use_triton=True, causal=True,
|
||||
)
|
||||
|
||||
q_block_num = (q_len + 127) // 128
|
||||
k_block_num = (k_len + 127) // 128
|
||||
|
||||
q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
|
||||
k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
|
||||
v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
|
||||
|
||||
__import__('pdb').set_trace()
|
||||
|
||||
output = block_sparse_attn_func(
|
||||
q, k, v,
|
||||
torch.tensor([0, q_len], dtype=torch.int32, device=q.device),
|
||||
torch.tensor([0, k_len], dtype=torch.int32, device=k.device),
|
||||
torch.ones(num_heads, dtype=torch.int32, device=q.device),
|
||||
None,
|
||||
mask[:, :, :q_block_num, :k_block_num].contiguous(),
|
||||
q_len, k_len,
|
||||
p_dropout=0.0, deterministic=True, is_causal=True,
|
||||
)
|
||||
return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
|
||||
|
||||
|
||||
DEBUG = False # Set to True to enable debugging
|
||||
|
||||
def create_patched_forward(prefill_method="xattn", threshold=0.9):
|
||||
"""Create patched forward with configurable prefill method.
|
||||
|
||||
Args:
|
||||
prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense)
|
||||
threshold: XAttention threshold for block selection (only used when prefill_method="xattn")
|
||||
|
||||
Note:
|
||||
- Prefill (q_len > 1): Uses specified prefill_method
|
||||
- Decode (q_len = 1): Always uses FlashAttention (no sparse needed for single query)
|
||||
"""
|
||||
call_count = [0] # Mutable to track calls across layers
|
||||
|
||||
def patched_forward(
|
||||
self,
|
||||
hidden_states,
|
||||
position_embeddings=None,
|
||||
attention_mask=None,
|
||||
past_key_value=None, # Old API (transformers < 4.57)
|
||||
past_key_values=None, # New API (transformers >= 4.57)
|
||||
cache_position=None,
|
||||
**kwargs
|
||||
):
|
||||
# Handle both old and new transformers API
|
||||
kv_cache = past_key_values if past_key_values is not None else past_key_value
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
num_heads = self.config.num_attention_heads
|
||||
num_kv_heads = self.config.num_key_value_heads
|
||||
head_dim = self.head_dim
|
||||
|
||||
# Compute Q, K, V projections
|
||||
query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
|
||||
key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
|
||||
value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Apply rotary position embedding
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
# Handle KV cache
|
||||
if kv_cache is not None:
|
||||
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
||||
key_states, value_states = kv_cache.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs
|
||||
)
|
||||
|
||||
# Expand KV for GQA
|
||||
key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads)
|
||||
|
||||
# Debug output
|
||||
if DEBUG and self.layer_idx == 0:
|
||||
call_count[0] += 1
|
||||
if call_count[0] <= 5:
|
||||
phase = "prefill" if q_len > 1 else "decode"
|
||||
print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}")
|
||||
print(f" kv_cache is None: {kv_cache is None}")
|
||||
|
||||
# Choose attention method:
|
||||
# - Prefill (q_len > 1): Use prefill_method (xattn or flash)
|
||||
# - Decode (q_len = 1): Always use FlashAttention
|
||||
is_prefill = q_len > 1
|
||||
|
||||
if is_prefill and prefill_method == "xattn":
|
||||
# Prefill with XAttention + BSA (sparse)
|
||||
attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold)
|
||||
else:
|
||||
# Prefill with FlashAttention (dense) OR Decode (always FlashAttention)
|
||||
# Note: For decode (q_len=1), causal=False since single query attends to all KV
|
||||
attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill)
|
||||
|
||||
attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1))
|
||||
return attn_output, None
|
||||
|
||||
return patched_forward
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Data & Evaluation
|
||||
# ============================================================
|
||||
|
||||
def load_samples(filepath, indices=None):
|
||||
"""Load samples from JSONL file."""
|
||||
samples = []
|
||||
with open(filepath) as f:
|
||||
for i, line in enumerate(f):
|
||||
if indices is None or i in indices:
|
||||
sample = json.loads(line)
|
||||
sample["_idx"] = i
|
||||
samples.append(sample)
|
||||
return samples
|
||||
|
||||
|
||||
def string_match_all(output_text, expected_list):
|
||||
"""RULER metric: fraction of expected values found in output."""
|
||||
output_lower = output_text.lower().replace('\n', ' ')
|
||||
if not expected_list:
|
||||
return 1.0
|
||||
return sum(1.0 if exp.strip().lower() in output_lower else 0.0 for exp in expected_list) / len(expected_list)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Test
|
||||
# ============================================================
|
||||
|
||||
def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50):
|
||||
"""Test attention methods using RULER data.
|
||||
|
||||
Args:
|
||||
prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention
|
||||
"""
|
||||
prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)"
|
||||
|
||||
print("=" * 60)
|
||||
print("RULER NIAH Attention Test")
|
||||
print("=" * 60)
|
||||
print(f"Data: {data_file}")
|
||||
print(f"Samples: {sample_ids}")
|
||||
print(f"Prefill method: {prefill_desc}")
|
||||
print(f"Decode method: FlashAttention (always)")
|
||||
if prefill_method == "xattn":
|
||||
print(f"XAttention threshold: {threshold}")
|
||||
|
||||
samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None)
|
||||
if not samples:
|
||||
print("No samples found!")
|
||||
return False
|
||||
print(f"Loaded {len(samples)} samples")
|
||||
|
||||
# Load model
|
||||
print(f"\nLoading model: {model_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path, torch_dtype=torch.float16, device_map="cuda",
|
||||
attn_implementation="eager", # Will be patched
|
||||
)
|
||||
model.eval()
|
||||
|
||||
# Patch all layers
|
||||
print(f"Patching attention layers...")
|
||||
print(f" - Prefill: {prefill_desc}")
|
||||
print(f" - Decode: FlashAttention")
|
||||
for idx, layer in enumerate(model.model.layers):
|
||||
layer.self_attn.layer_idx = idx # Ensure layer_idx is set
|
||||
layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__(
|
||||
layer.self_attn, type(layer.self_attn)
|
||||
)
|
||||
|
||||
total_score = 0.0
|
||||
results = []
|
||||
|
||||
for sample in samples:
|
||||
idx = sample["_idx"]
|
||||
prompt = sample["input"]
|
||||
expected = sample["outputs"]
|
||||
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
|
||||
num_tokens = inputs["input_ids"].shape[1]
|
||||
print(f"\n--- Sample {idx} ({num_tokens} tokens) ---")
|
||||
print(f"Expected: {expected}")
|
||||
|
||||
with torch.no_grad():
|
||||
output = model.generate(
|
||||
inputs["input_ids"],
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=False,
|
||||
pad_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True)
|
||||
score = string_match_all(output_text, expected)
|
||||
total_score += score
|
||||
|
||||
status = "✓ PASS" if score >= 0.5 else "✗ FAIL"
|
||||
print(f"Output: '{output_text[:100]}...'")
|
||||
print(f"Result: {status} (score={score:.2f})")
|
||||
results.append({"idx": idx, "score": score, "passed": score >= 0.5})
|
||||
|
||||
avg_score = total_score / len(samples)
|
||||
passed = sum(1 for r in results if r["passed"])
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
return avg_score >= 0.5
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark"
|
||||
)
|
||||
parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct")
|
||||
parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl")
|
||||
parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index")
|
||||
parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)")
|
||||
parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn",
|
||||
help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)")
|
||||
parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)")
|
||||
parser.add_argument("--max-new-tokens", type=int, default=50)
|
||||
# Keep old option for backwards compatibility
|
||||
parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead")
|
||||
args = parser.parse_args()
|
||||
|
||||
model_path = args.model.replace("~", "/home/zijie")
|
||||
|
||||
# Handle deprecated --no-xattn option
|
||||
prefill_method = args.prefill_method
|
||||
if args.no_xattn:
|
||||
prefill_method = "flash"
|
||||
print("Warning: --no-xattn is deprecated, use --prefill-method flash instead")
|
||||
|
||||
if args.sample_id is not None:
|
||||
sample_ids = [args.sample_id]
|
||||
elif args.sample_ids:
|
||||
sample_ids = [int(x) for x in args.sample_ids.split(",")]
|
||||
else:
|
||||
sample_ids = [0]
|
||||
|
||||
# Check BSA availability if using xattn
|
||||
if prefill_method == "xattn":
|
||||
try:
|
||||
from block_sparse_attn import block_sparse_attn_func
|
||||
print("✓ BSA (Block Sparse Attention) available")
|
||||
except ImportError:
|
||||
print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash")
|
||||
sys.exit(1)
|
||||
|
||||
if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens):
|
||||
print("\ntest_xattn_bsa: PASSED")
|
||||
else:
|
||||
print("\ntest_xattn_bsa: FAILED")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,259 +0,0 @@
|
||||
"""
|
||||
Test: Compare xattn_estimate vs xattn_estimate_chunked
|
||||
Verify that chunked estimation with EXTERNAL chunking produces the same mask as standard estimation.
|
||||
|
||||
Uses real QKV data captured from model inference.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import torch
|
||||
import warnings
|
||||
|
||||
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
|
||||
|
||||
# ============================================================
|
||||
# Configuration
|
||||
# ============================================================
|
||||
|
||||
BLOCK_SIZE = 64
|
||||
STRIDE = 4
|
||||
THRESHOLD = 0.9
|
||||
CHUNK_SIZE = 4096
|
||||
|
||||
# Default QKV data directory (relative to project root)
|
||||
DEFAULT_QKV_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "kvcache")
|
||||
|
||||
# ============================================================
|
||||
# Utility Functions
|
||||
# ============================================================
|
||||
|
||||
def load_qkv(path):
|
||||
"""Load saved QKV data."""
|
||||
data = torch.load(path, map_location="cpu", weights_only=False)
|
||||
print(f"Loaded: {path}")
|
||||
print(f" Query shape: {data['query'].shape}")
|
||||
print(f" Key shape: {data['key'].shape}")
|
||||
print(f" Layer: {data['layer_id']}, Density: {data['density']:.2%}")
|
||||
return data
|
||||
|
||||
|
||||
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
|
||||
"""Compare two masks and report differences."""
|
||||
if mask1.shape != mask2.shape:
|
||||
print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
|
||||
return False
|
||||
|
||||
diff = (mask1 != mask2).sum().item()
|
||||
total = mask1.numel()
|
||||
match_rate = (total - diff) / total * 100
|
||||
|
||||
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
|
||||
|
||||
if diff > 0:
|
||||
diff_indices = torch.where(mask1 != mask2)
|
||||
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
|
||||
|
||||
return diff == 0
|
||||
|
||||
|
||||
def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size):
|
||||
"""
|
||||
Run xattn_estimate_chunked with EXTERNAL chunking.
|
||||
This simulates how chunked prefill should be used in practice.
|
||||
"""
|
||||
batch_size, num_heads, q_len, head_dim = query.shape
|
||||
_, _, k_len, _ = key.shape
|
||||
|
||||
q_block_num = (q_len + block_size - 1) // block_size
|
||||
k_block_num = (k_len + block_size - 1) // block_size
|
||||
|
||||
# If Q fits in one chunk, call directly
|
||||
if q_len <= chunk_size:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
return xattn_estimate_chunked(
|
||||
query, key,
|
||||
q_start_pos=q_start_pos,
|
||||
block_size=block_size,
|
||||
stride=stride,
|
||||
threshold=threshold,
|
||||
use_triton=True,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# External chunking: split Q and call for each chunk
|
||||
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
|
||||
print(f" External chunking: {num_q_chunks} chunks")
|
||||
|
||||
combined_attn_sum = torch.zeros(
|
||||
batch_size, num_heads, q_block_num, k_block_num,
|
||||
dtype=query.dtype, device=query.device
|
||||
)
|
||||
combined_mask = torch.zeros(
|
||||
batch_size, num_heads, q_block_num, k_block_num,
|
||||
dtype=torch.bool, device=query.device
|
||||
)
|
||||
|
||||
q_block_offset = 0
|
||||
for q_chunk_idx in range(num_q_chunks):
|
||||
q_chunk_start = q_chunk_idx * chunk_size
|
||||
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
|
||||
|
||||
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
|
||||
|
||||
# For causal attention, K accumulates up to current Q position
|
||||
k_end = q_start_pos + q_chunk_end
|
||||
k_chunk = key[:, :, :k_end, :]
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore")
|
||||
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
|
||||
q_chunk, k_chunk,
|
||||
q_start_pos=q_start_pos + q_chunk_start,
|
||||
block_size=block_size,
|
||||
stride=stride,
|
||||
threshold=threshold,
|
||||
use_triton=True,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# Place chunk results into combined output
|
||||
chunk_q_blocks = mask_chunk.shape[2]
|
||||
chunk_k_blocks = mask_chunk.shape[3]
|
||||
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
|
||||
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
|
||||
q_block_offset += chunk_q_blocks
|
||||
|
||||
return combined_attn_sum, combined_mask
|
||||
|
||||
|
||||
def test_single_qkv(qkv_path):
|
||||
"""Test a single QKV file."""
|
||||
data = load_qkv(qkv_path)
|
||||
query = data["query"].cuda().to(torch.bfloat16)
|
||||
key = data["key"].cuda().to(torch.bfloat16)
|
||||
|
||||
seq_len = query.shape[2]
|
||||
print(f"\nTesting with seq_len={seq_len}")
|
||||
print("=" * 60)
|
||||
|
||||
# Run standard xattn_estimate
|
||||
print("[1] Running standard xattn_estimate...")
|
||||
try:
|
||||
attn_sum_std, mask_std = xattn_estimate(
|
||||
query, key,
|
||||
block_size=BLOCK_SIZE,
|
||||
stride=STRIDE,
|
||||
threshold=THRESHOLD,
|
||||
chunk_size=CHUNK_SIZE,
|
||||
use_triton=True,
|
||||
)
|
||||
print(f" mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}")
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# Run chunked xattn_estimate with EXTERNAL chunking
|
||||
print("[2] Running chunked xattn_estimate (external chunking)...")
|
||||
try:
|
||||
attn_sum_chunked, mask_chunked = run_chunked_externally(
|
||||
query, key,
|
||||
q_start_pos=0,
|
||||
block_size=BLOCK_SIZE,
|
||||
stride=STRIDE,
|
||||
threshold=THRESHOLD,
|
||||
chunk_size=CHUNK_SIZE,
|
||||
)
|
||||
print(f" mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}")
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# Compare results
|
||||
print("[3] Comparing results...")
|
||||
chunked_q_blocks = mask_chunked.shape[2]
|
||||
chunked_k_blocks = mask_chunked.shape[3]
|
||||
|
||||
# Extract comparable region from standard mask
|
||||
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
||||
|
||||
# Compare masks
|
||||
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
|
||||
|
||||
# Compare attn_sums
|
||||
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
||||
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
|
||||
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
|
||||
print(f" Attn sum max diff: {attn_diff:.6f}")
|
||||
else:
|
||||
print(f" Attn sum shape mismatch")
|
||||
|
||||
# Clean up GPU memory
|
||||
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return masks_match
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Test
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked")
|
||||
parser.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR,
|
||||
help="Directory containing QKV files")
|
||||
args = parser.parse_args()
|
||||
|
||||
# QKV files to test
|
||||
qkv_files = [
|
||||
os.path.join(args.qkv_dir, "qkv_3688.pt"), # ~4K
|
||||
os.path.join(args.qkv_dir, "qkv_7888.pt"), # ~8K
|
||||
os.path.join(args.qkv_dir, "qkv_15685.pt"), # ~16K
|
||||
os.path.join(args.qkv_dir, "qkv_32485.pt"), # ~32K
|
||||
os.path.join(args.qkv_dir, "qkv_64891.pt"), # ~64K
|
||||
]
|
||||
|
||||
available_files = [p for p in qkv_files if os.path.exists(p)]
|
||||
|
||||
if not available_files:
|
||||
print(f"No QKV file found in {args.qkv_dir}.")
|
||||
print(f"Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Found {len(available_files)} QKV files to test")
|
||||
print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})")
|
||||
print(f"Using Triton kernels")
|
||||
|
||||
all_passed = True
|
||||
results = []
|
||||
|
||||
for qkv_path in available_files:
|
||||
passed = test_single_qkv(qkv_path)
|
||||
seq_len = int(os.path.basename(qkv_path).replace("qkv_", "").replace(".pt", ""))
|
||||
results.append((seq_len, passed))
|
||||
if not passed:
|
||||
all_passed = False
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("SUMMARY")
|
||||
print("=" * 60)
|
||||
for seq_len, passed in results:
|
||||
status = "PASSED" if passed else "FAILED"
|
||||
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
|
||||
print(f" seq_len={seq_len} ({chunks} chunk{'s' if chunks > 1 else ''}): {status}")
|
||||
|
||||
print("=" * 60)
|
||||
if all_passed:
|
||||
print("test_xattn_chunked: PASSED")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("test_xattn_chunked: FAILED")
|
||||
sys.exit(1)
|
||||
@@ -1,244 +0,0 @@
|
||||
"""
|
||||
Test: Compare xattn_estimate vs xattn_estimate_chunked
|
||||
|
||||
Verify that chunked estimation with EXTERNAL chunking produces the same mask
|
||||
as standard estimation. This ensures the chunked version can be used in
|
||||
chunked prefill scenarios without accuracy loss.
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_xattn_estimate_chunked.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
import torch
|
||||
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
|
||||
|
||||
# ============================================================
|
||||
# Configuration
|
||||
# ============================================================
|
||||
|
||||
# Configuration for xattn_estimate_chunked consistency test.
|
||||
# Key requirements for 100% match:
|
||||
# 1. Use matching chunk_size for both standard and chunked versions
|
||||
# 2. Use same random seed for reproducibility
|
||||
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
|
||||
# floating point precision in cumulative sum calculations.
|
||||
BLOCK_SIZE = 64
|
||||
STRIDE = 4
|
||||
THRESHOLD = 0.9
|
||||
CHUNK_SIZE = 4096 # External chunking size
|
||||
|
||||
# Test sequence lengths
|
||||
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]
|
||||
|
||||
# ============================================================
|
||||
# Utility Functions
|
||||
# ============================================================
|
||||
|
||||
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
|
||||
"""Compare two masks and report differences."""
|
||||
if mask1.shape != mask2.shape:
|
||||
print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
|
||||
return False
|
||||
|
||||
diff = (mask1 != mask2).sum().item()
|
||||
total = mask1.numel()
|
||||
match_rate = (total - diff) / total * 100
|
||||
|
||||
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
|
||||
|
||||
if diff > 0:
|
||||
diff_indices = torch.where(mask1 != mask2)
|
||||
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
|
||||
|
||||
return diff == 0
|
||||
|
||||
|
||||
def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
|
||||
"""
|
||||
Run xattn_estimate_chunked with EXTERNAL chunking.
|
||||
This simulates how chunked prefill should be used in practice.
|
||||
"""
|
||||
batch_size, num_heads, q_len, head_dim = query.shape
|
||||
_, _, k_len, _ = key.shape
|
||||
|
||||
q_block_num = (q_len + block_size - 1) // block_size
|
||||
k_block_num = (k_len + block_size - 1) // block_size
|
||||
|
||||
# If Q fits in one chunk, call directly
|
||||
if q_len <= chunk_size:
|
||||
return xattn_estimate_chunked(
|
||||
query, key,
|
||||
q_start_pos=0,
|
||||
block_size=block_size,
|
||||
stride=stride,
|
||||
threshold=threshold,
|
||||
use_triton=True,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# External chunking: split Q and call for each chunk
|
||||
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
|
||||
print(f" External chunking: {num_q_chunks} chunks")
|
||||
|
||||
combined_attn_sum = torch.zeros(
|
||||
batch_size, num_heads, q_block_num, k_block_num,
|
||||
dtype=query.dtype, device=query.device
|
||||
)
|
||||
combined_mask = torch.zeros(
|
||||
batch_size, num_heads, q_block_num, k_block_num,
|
||||
dtype=torch.bool, device=query.device
|
||||
)
|
||||
|
||||
q_block_offset = 0
|
||||
for q_chunk_idx in range(num_q_chunks):
|
||||
q_chunk_start = q_chunk_idx * chunk_size
|
||||
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
|
||||
|
||||
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
|
||||
|
||||
# For causal attention, K accumulates up to current Q position
|
||||
# q_start_pos=0 means Q starts at position 0 in the full sequence
|
||||
# K is [0, q_chunk_end) for causal attention
|
||||
k_end = q_chunk_end
|
||||
k_chunk = key[:, :, :k_end, :]
|
||||
|
||||
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
|
||||
q_chunk, k_chunk,
|
||||
q_start_pos=q_chunk_start,
|
||||
block_size=block_size,
|
||||
stride=stride,
|
||||
threshold=threshold,
|
||||
use_triton=True,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# Place chunk results into combined output
|
||||
chunk_q_blocks = mask_chunk.shape[2]
|
||||
chunk_k_blocks = mask_chunk.shape[3]
|
||||
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
|
||||
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
|
||||
q_block_offset += chunk_q_blocks
|
||||
|
||||
return combined_attn_sum, combined_mask
|
||||
|
||||
|
||||
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
|
||||
"""Test a single sequence length."""
|
||||
print(f"\nTesting seq_len={seq_len}")
|
||||
print("=" * 60)
|
||||
|
||||
# Generate random Q/K
|
||||
query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
|
||||
key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Run standard xattn_estimate
|
||||
print("[1] Running standard xattn_estimate...")
|
||||
try:
|
||||
attn_sum_std, mask_std = xattn_estimate(
|
||||
query, key,
|
||||
block_size=BLOCK_SIZE,
|
||||
stride=STRIDE,
|
||||
threshold=THRESHOLD,
|
||||
chunk_size=CHUNK_SIZE,
|
||||
use_triton=True,
|
||||
causal=True,
|
||||
)
|
||||
density_std = mask_std.float().mean().item()
|
||||
print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# Run chunked xattn_estimate with EXTERNAL chunking
|
||||
print("[2] Running chunked xattn_estimate (external chunking)...")
|
||||
try:
|
||||
attn_sum_chunked, mask_chunked = run_chunked_externally(
|
||||
query, key,
|
||||
block_size=BLOCK_SIZE,
|
||||
stride=STRIDE,
|
||||
threshold=THRESHOLD,
|
||||
chunk_size=CHUNK_SIZE,
|
||||
)
|
||||
density_chunked = mask_chunked.float().mean().item()
|
||||
print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
|
||||
except Exception as e:
|
||||
print(f" ERROR: {e}")
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
# Compare results
|
||||
print("[3] Comparing results...")
|
||||
chunked_q_blocks = mask_chunked.shape[2]
|
||||
chunked_k_blocks = mask_chunked.shape[3]
|
||||
|
||||
# Extract comparable region from standard mask
|
||||
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
||||
|
||||
# Compare masks
|
||||
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
|
||||
|
||||
# Compare attn_sums
|
||||
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
|
||||
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
|
||||
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
|
||||
print(f" Attn sum max diff: {attn_diff:.6f}")
|
||||
else:
|
||||
print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")
|
||||
|
||||
# Clean up GPU memory
|
||||
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return masks_match
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Main Test
|
||||
# ============================================================
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("XAttention Chunked vs Standard Test")
|
||||
print("=" * 60)
|
||||
print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
|
||||
print(f"External chunk_size={CHUNK_SIZE}")
|
||||
print()
|
||||
|
||||
# Check CUDA availability
|
||||
if not torch.cuda.is_available():
|
||||
print("CUDA not available!")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
||||
print("✓ xattn_estimate imported")
|
||||
print("✓ xattn_estimate_chunked imported")
|
||||
|
||||
# Run tests
|
||||
all_passed = True
|
||||
results = []
|
||||
|
||||
for seq_len in TEST_SEQ_LENS:
|
||||
passed = test_single_seq_len(seq_len)
|
||||
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
|
||||
results.append((seq_len, chunks, passed))
|
||||
if not passed:
|
||||
all_passed = False
|
||||
|
||||
# Summary
|
||||
print("\n" + "=" * 60)
|
||||
print("SUMMARY")
|
||||
print("=" * 60)
|
||||
for seq_len, chunks, passed in results:
|
||||
status = "PASSED" if passed else "FAILED"
|
||||
print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")
|
||||
|
||||
print("=" * 60)
|
||||
if all_passed:
|
||||
print("ALL TESTS PASSED!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("SOME TESTS FAILED!")
|
||||
sys.exit(1)
|
||||
@@ -1,132 +0,0 @@
|
||||
"""
|
||||
Test: XAttention Triton kernels
|
||||
|
||||
演示 XAttention 的两个核心 Triton kernel:
|
||||
1. flat_group_gemm_fuse_reshape: 计算 stride reshape 后的 attention scores (反对角线求和)
|
||||
2. softmax_fuse_block_sum: 对 attention scores 做 softmax 后按 block 求和
|
||||
|
||||
数据流:
|
||||
Q [batch, heads, q_len, head_dim]
|
||||
K [batch, heads, kv_len, head_dim]
|
||||
↓ flat_group_gemm_fuse_reshape
|
||||
attn_scores [batch, heads, q_len/stride, kv_len/stride]
|
||||
↓ softmax_fuse_block_sum
|
||||
block_sums [batch, heads, q_blocks, k_blocks]
|
||||
"""
|
||||
import torch
|
||||
import sys
|
||||
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
|
||||
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum
|
||||
|
||||
# ============================================================
|
||||
# 参数配置
|
||||
# ============================================================
|
||||
|
||||
# Triton 约束: q_len >= stride * BLOCK_M, kv_len >= stride * BLOCK_N
|
||||
# A100: BLOCK_M = BLOCK_N = 128, 所以 min = 4 * 128 = 512
|
||||
# RTX 3090: BLOCK_M = BLOCK_N = 64, 所以 min = 4 * 64 = 256
|
||||
q_len = 512
|
||||
kv_len = 2048
|
||||
head_dim = 128
|
||||
stride = 4
|
||||
block_size = 128 # softmax block size (in reshaped space)
|
||||
segment_size = 128 # Triton kernel 要求 segment_size >= block_size
|
||||
|
||||
# ============================================================
|
||||
# 构造输入: 偶数位置=1, 奇数位置=2
|
||||
# ============================================================
|
||||
|
||||
Q = torch.zeros(1, 1, q_len, head_dim, dtype=torch.bfloat16).cuda()
|
||||
K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda()
|
||||
|
||||
for i in range(q_len):
|
||||
if i % 2 == 0:
|
||||
Q[0, 0, i, :] = 1 * (i // stride + 1)
|
||||
else:
|
||||
Q[0, 0, i, :] = 2 * (i // stride + 1)
|
||||
|
||||
for i in range(kv_len):
|
||||
if i % 2 == 0:
|
||||
K[0, 0, i, :] = 1
|
||||
else:
|
||||
K[0, 0, i, :] = 2
|
||||
|
||||
# ============================================================
|
||||
# Step 1: flat_group_gemm_fuse_reshape (chunked along K)
|
||||
# ============================================================
|
||||
|
||||
q_reshaped_len = q_len // stride # 128
|
||||
kv_reshaped_len = kv_len // stride # 512
|
||||
|
||||
# 将 K 沿着长度维度分成多个 chunk
|
||||
k_chunk_size = 512 # 每个 chunk 512 tokens
|
||||
num_k_chunks = kv_len // k_chunk_size # 4 chunks
|
||||
|
||||
attn_scores_list = []
|
||||
for k_chunk_idx in range(num_k_chunks):
|
||||
k_start = k_chunk_idx * k_chunk_size
|
||||
k_end = k_start + k_chunk_size
|
||||
K_chunk = K[:, :, k_start:k_end, :] # [1, 1, k_chunk_size, head_dim]
|
||||
|
||||
# 对每个 K chunk 调用 flat_group_gemm_fuse_reshape
|
||||
# 输出: [batch, heads, q_len/stride, k_chunk_size/stride]
|
||||
attn_chunk = flat_group_gemm_fuse_reshape(
|
||||
Q, K_chunk, stride,
|
||||
chunk_start=0,
|
||||
chunk_end=q_reshaped_len,
|
||||
is_causal=True
|
||||
)
|
||||
|
||||
__import__('pdb').set_trace()
|
||||
|
||||
attn_scores_list.append(attn_chunk)
|
||||
|
||||
# 拼接所有 K chunks 的结果
|
||||
# 每个 chunk: [1, 1, q_reshaped_len, k_chunk_size/stride]
|
||||
# 拼接后: [1, 1, q_reshaped_len, kv_reshaped_len]
|
||||
attn_scores = torch.cat(attn_scores_list, dim=-1)
|
||||
|
||||
# 验证 shape: [batch, heads, q_len/stride, kv_len/stride]
|
||||
assert attn_scores.shape == (1, 1, q_reshaped_len, kv_reshaped_len), \
|
||||
f"shape mismatch: {attn_scores.shape} != (1, 1, {q_reshaped_len}, {kv_reshaped_len})"
|
||||
|
||||
# 验证: 反对角线求和
|
||||
# 每个 stride x stride 块的反对角线: Q[奇]*K[偶] + Q[偶]*K[奇] = 2*1 + 1*2 = 4
|
||||
# 反对角线有 stride/2 对,再乘以 head_dim
|
||||
expected_gemm = (2*1 + 1*2) * (stride // 2) * head_dim
|
||||
actual_gemm = attn_scores[0, 0, 0, 0].item()
|
||||
assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expected_gemm}"
|
||||
|
||||
# ============================================================
|
||||
# Step 2: softmax_fuse_block_sum
|
||||
# ============================================================
|
||||
|
||||
scale = 1.4426950408889634 # log2(e) for exp2
|
||||
|
||||
block_sums = softmax_fuse_block_sum(
|
||||
attn_scores,
|
||||
block_size,
|
||||
segment_size,
|
||||
chunk_start=0,
|
||||
chunk_end=q_reshaped_len,
|
||||
real_q_len=q_reshaped_len,
|
||||
scale=scale,
|
||||
is_causal=False
|
||||
)
|
||||
|
||||
# 验证 shape: [batch, heads, q_blocks, k_blocks]
|
||||
q_blocks = q_reshaped_len // block_size # 128 / 128 = 1
|
||||
k_blocks = kv_reshaped_len // block_size # 512 / 128 = 4
|
||||
assert block_sums.shape == (1, 1, q_blocks, k_blocks), \
|
||||
f"shape mismatch: {block_sums.shape} != (1, 1, {q_blocks}, {k_blocks})"
|
||||
|
||||
# 验证: 每个 block 的 softmax 结果求和
|
||||
# 所有 attn_scores 相同 → softmax 均匀分布
|
||||
# 每行对一个 K block 的贡献 = block_size / kv_reshaped_len
|
||||
# 每个 Q block 有 block_size 行
|
||||
# block_sum = block_size * (block_size / kv_reshaped_len)
|
||||
expected_sum = block_size * block_size / kv_reshaped_len
|
||||
actual_sum = block_sums[0, 0, 0, 0].item()
|
||||
assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}"
|
||||
|
||||
print("test_xattn_kernels: PASSED")
|
||||
@@ -1,246 +0,0 @@
|
||||
"""
|
||||
Test: 批量验证 xattn_estimate 与 KV chunking kernels 的一致性
|
||||
|
||||
测试 results/kvcache 下所有保存的 QKV 数据
|
||||
|
||||
Usage:
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
||||
python tests/test_xattn_kv_chunking_batch.py
|
||||
"""
|
||||
import sys
|
||||
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
|
||||
|
||||
import os
|
||||
import glob
|
||||
import torch
|
||||
import math
|
||||
from nanovllm.ops.xattn import (
|
||||
xattn_estimate,
|
||||
flat_group_gemm_fuse_reshape,
|
||||
softmax_compute_partial_stats,
|
||||
softmax_normalize_and_block_sum,
|
||||
merge_softmax_stats,
|
||||
find_blocks_chunked,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# 参数配置
|
||||
# ============================================================
|
||||
DATA_DIR = "/home/zijie/Code/nano-vllm/results/kvcache"
|
||||
BSA_BLOCK_SIZE = 128
|
||||
CHUNK_SIZE = 16384
|
||||
|
||||
device = "cuda"
|
||||
|
||||
|
||||
def test_single_file(data_file: str) -> dict:
|
||||
"""测试单个 kvcache 文件"""
|
||||
data = torch.load(data_file, map_location="cpu")
|
||||
Q = data["query"].to(device)
|
||||
K = data["key"].to(device)
|
||||
|
||||
batch_size, num_heads, seq_len, head_dim = Q.shape
|
||||
STRIDE = data["stride"]
|
||||
THRESHOLD = data["threshold"][0].item() if isinstance(data["threshold"], torch.Tensor) else data["threshold"]
|
||||
|
||||
# ========== xattn_estimate API ==========
|
||||
attn_sums_api, mask_api = xattn_estimate(
|
||||
Q, K,
|
||||
block_size=BSA_BLOCK_SIZE,
|
||||
stride=STRIDE,
|
||||
threshold=THRESHOLD,
|
||||
chunk_size=CHUNK_SIZE,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
|
||||
k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
|
||||
mask_api_valid = mask_api[:, :, :q_blocks, :k_blocks]
|
||||
|
||||
causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, device=device, dtype=torch.bool))
|
||||
total_api = causal_mask.sum().item() * batch_size * num_heads
|
||||
selected_api = (mask_api_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
density_api = selected_api / total_api
|
||||
|
||||
# ========== 三阶段 KV Chunking ==========
|
||||
k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
|
||||
q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len
|
||||
q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE
|
||||
kv_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE
|
||||
|
||||
k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE
|
||||
q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE
|
||||
|
||||
reshaped_chunk_size = CHUNK_SIZE // STRIDE
|
||||
reshaped_block_size = BSA_BLOCK_SIZE // STRIDE
|
||||
k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE
|
||||
k_reshaped_num_to_pad = k_num_to_pad // STRIDE
|
||||
num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size
|
||||
kv_reshaped_chunk_size = CHUNK_SIZE // STRIDE
|
||||
|
||||
if k_num_to_pad > 0:
|
||||
K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0)
|
||||
else:
|
||||
K_padded = K
|
||||
|
||||
if q_num_to_pad > 0:
|
||||
Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0)
|
||||
else:
|
||||
Q_padded = Q
|
||||
|
||||
norm = 1.0
|
||||
scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm
|
||||
|
||||
simple_mask_list = []
|
||||
|
||||
for q_chunk_idx in range(q_chunk_num):
|
||||
q_start = q_chunk_idx * reshaped_chunk_size * STRIDE
|
||||
q_end = q_start + reshaped_chunk_size * STRIDE
|
||||
Q_chunk = Q_padded[:, :, q_start:q_end, :]
|
||||
|
||||
chunk_start = (k_block_num - q_block_num) * reshaped_block_size + q_chunk_idx * reshaped_chunk_size
|
||||
chunk_end = chunk_start + reshaped_chunk_size
|
||||
|
||||
m_chunks = []
|
||||
l_chunks = []
|
||||
attn_weights_chunks = []
|
||||
|
||||
for kv_chunk_idx in range(kv_chunk_num):
|
||||
kv_start = kv_chunk_idx * CHUNK_SIZE
|
||||
kv_end = kv_start + CHUNK_SIZE
|
||||
K_chunk = K_padded[:, :, kv_start:kv_end, :]
|
||||
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
|
||||
|
||||
attn_weights_kv = flat_group_gemm_fuse_reshape(
|
||||
Q_chunk, K_chunk, STRIDE,
|
||||
chunk_start=chunk_start,
|
||||
chunk_end=chunk_end,
|
||||
is_causal=False,
|
||||
)
|
||||
attn_weights_chunks.append(attn_weights_kv)
|
||||
|
||||
m_partial, l_partial = softmax_compute_partial_stats(
|
||||
attn_weights_kv,
|
||||
reshaped_block_size,
|
||||
min(4096, reshaped_block_size),
|
||||
scale,
|
||||
chunk_start=chunk_start,
|
||||
kv_offset=kv_offset_reshaped,
|
||||
is_causal=True,
|
||||
)
|
||||
m_chunks.append(m_partial)
|
||||
l_chunks.append(l_partial)
|
||||
|
||||
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
|
||||
|
||||
attn_sum_per_kv = []
|
||||
for kv_chunk_idx, attn_weights_kv in enumerate(attn_weights_chunks):
|
||||
kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size
|
||||
attn_sum_kv = softmax_normalize_and_block_sum(
|
||||
attn_weights_kv,
|
||||
m_global,
|
||||
l_global,
|
||||
reshaped_block_size,
|
||||
min(4096, reshaped_block_size),
|
||||
chunk_start=chunk_start,
|
||||
real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad,
|
||||
scale=scale,
|
||||
kv_offset=kv_offset_reshaped,
|
||||
is_causal=True,
|
||||
)
|
||||
attn_sum_per_kv.append(attn_sum_kv)
|
||||
|
||||
attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
|
||||
|
||||
simple_mask = find_blocks_chunked(
|
||||
attn_sum_concat,
|
||||
current_index=k_block_num - q_block_num + q_chunk_idx * num_blocks_per_chunk,
|
||||
threshold=THRESHOLD,
|
||||
num_to_choose=None,
|
||||
decoding=False,
|
||||
mode="prefill",
|
||||
causal=True,
|
||||
)
|
||||
simple_mask_list.append(simple_mask)
|
||||
|
||||
mask_kv_chunking = torch.cat(simple_mask_list, dim=2)
|
||||
|
||||
# 应用与 xattn_estimate 相同的 causal mask 后处理 (xattn.py 第 1300-1306 行)
|
||||
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where(
|
||||
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=device), diagonal=0),
|
||||
mask_kv_chunking[:, :, -q_block_num:, -q_block_num:],
|
||||
False,
|
||||
)
|
||||
|
||||
mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks]
|
||||
selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
density_kv = selected_kv / total_api
|
||||
|
||||
mask_total = mask_api_valid.numel()
|
||||
mask_diff = (mask_api_valid != mask_kv_chunking_valid).sum().item()
|
||||
mask_diff_pct = 100 * mask_diff / mask_total
|
||||
|
||||
return {
|
||||
"seq_len": seq_len,
|
||||
"stride": STRIDE,
|
||||
"threshold": THRESHOLD,
|
||||
"kv_chunks": kv_chunk_num,
|
||||
"density_api": density_api,
|
||||
"density_kv": density_kv,
|
||||
"density_diff": abs(density_api - density_kv),
|
||||
"mask_diff_pct": mask_diff_pct,
|
||||
"passed": abs(density_api - density_kv) < 1e-6 and mask_diff_pct < 0.01,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
files = sorted(glob.glob(os.path.join(DATA_DIR, "qkv_*.pt")))
|
||||
|
||||
print("=" * 80)
|
||||
print("XAttention KV Chunking Alignment Test")
|
||||
print("=" * 80)
|
||||
print()
|
||||
|
||||
results = []
|
||||
for f in files:
|
||||
fname = os.path.basename(f)
|
||||
print(f"Testing {fname}...", end=" ", flush=True)
|
||||
try:
|
||||
r = test_single_file(f)
|
||||
results.append(r)
|
||||
status = "✓ PASS" if r["passed"] else "✗ FAIL"
|
||||
print(f"{status} (seq_len={r['seq_len']}, kv_chunks={r['kv_chunks']})")
|
||||
except Exception as e:
|
||||
print(f"✗ ERROR: {e}")
|
||||
results.append({"file": fname, "error": str(e)})
|
||||
|
||||
print()
|
||||
print("=" * 80)
|
||||
print("Results Summary")
|
||||
print("=" * 80)
|
||||
print()
|
||||
print("| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff | mask_diff | status |")
|
||||
print("|---------|--------|-----------|-----------|-------------|------------|------|-----------|--------|")
|
||||
|
||||
all_passed = True
|
||||
for r in results:
|
||||
if "error" in r:
|
||||
print(f"| ERROR | - | - | - | - | - | - | - | {r['error'][:20]} |")
|
||||
all_passed = False
|
||||
else:
|
||||
status = "PASS" if r["passed"] else "FAIL"
|
||||
if not r["passed"]:
|
||||
all_passed = False
|
||||
print(f"| {r['seq_len']:>7} | {r['stride']:>6} | {r['threshold']:.2f} | {r['kv_chunks']:>9} | "
|
||||
f"{r['density_api']:.6f} | {r['density_kv']:.6f} | {r['density_diff']:.6f} | "
|
||||
f"{r['mask_diff_pct']:.4f}% | {status} |")
|
||||
|
||||
print()
|
||||
if all_passed:
|
||||
print("test_xattn_kv_chunking_batch: ALL PASSED")
|
||||
else:
|
||||
print("test_xattn_kv_chunking_batch: SOME FAILED")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user