diff --git a/tests/test_chunk_attention_graph.py b/tests/test_chunk_attention_graph.py deleted file mode 100644 index 00c18f7..0000000 --- a/tests/test_chunk_attention_graph.py +++ /dev/null @@ -1,151 +0,0 @@ -#!/usr/bin/env python3 -""" -Test: Pre-allocated chunk pair graphs for block sparse attention. - -Each (Q_chunk, K_chunk) pair has its own captured CUDA graph. -Zero copy_() during replay - all data pre-filled. - -Usage: - CUDA_VISIBLE_DEVICES=0 python tests/test_chunk_attention_graph.py -""" - -from dataclasses import dataclass -from typing import List, Optional - -import torch - -from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs - - -@dataclass -class ChunkAttentionGraph: - """Container for a captured chunk attention graph.""" - graph: torch.cuda.CUDAGraph - static_q: torch.Tensor - static_k: torch.Tensor - static_v: torch.Tensor - static_output: torch.Tensor - static_lse: torch.Tensor - causal: bool - - -def capture_chunk_attention_graph( - chunk_size: int, - num_heads: int, - num_kv_heads: int, - head_dim: int, - scale: float, - device: torch.device, - dtype: torch.dtype, - causal: bool = False, -) -> ChunkAttentionGraph: - """Capture a CUDA graph for single chunk attention.""" - static_q = torch.zeros(1, chunk_size, num_heads, head_dim, dtype=dtype, device=device) - static_k = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device) - static_v = torch.zeros(1, chunk_size, num_kv_heads, head_dim, dtype=dtype, device=device) - - static_q.normal_() - static_k.normal_() - static_v.normal_() - - # Warmup - with torch.inference_mode(): - for _ in range(3): - _ = flash_attn_with_lse(static_q, static_k, static_v, scale, causal) - torch.cuda.synchronize() - - # Capture - graph = torch.cuda.CUDAGraph() - with torch.inference_mode(): - with torch.cuda.graph(graph): - static_output, static_lse = flash_attn_with_lse(static_q, static_k, static_v, scale, causal) - - torch.cuda.synchronize() - - return ChunkAttentionGraph( - graph=graph, - static_q=static_q, - static_k=static_k, - static_v=static_v, - static_output=static_output, - static_lse=static_lse, - causal=causal, - ) - - -def main(): - device = torch.device("cuda") - dtype = torch.bfloat16 - - chunk_size = 64 - num_chunks = 4 - num_heads = 8 - num_kv_heads = 8 - head_dim = 64 - scale = 1.0 / (head_dim ** 0.5) - seq_len = chunk_size * num_chunks - - print(f"Device: {torch.cuda.get_device_name()}") - print(f"Chunk size: {chunk_size}, Num chunks: {num_chunks}") - print(f"Total graphs: {num_chunks * (num_chunks + 1) // 2}") - - # Test data - full_q = torch.randn(1, seq_len, num_heads, head_dim, dtype=dtype, device=device) - full_k = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device) - full_v = torch.randn(1, seq_len, num_kv_heads, head_dim, dtype=dtype, device=device) - - # Reference - with torch.inference_mode(): - full_output, _ = flash_attn_with_lse(full_q, full_k, full_v, scale, causal=True) - - # Capture all graphs - graphs: List[List[Optional[ChunkAttentionGraph]]] = [[None] * num_chunks for _ in range(num_chunks)] - for q_idx in range(num_chunks): - for k_idx in range(q_idx + 1): - graphs[q_idx][k_idx] = capture_chunk_attention_graph( - chunk_size, num_heads, num_kv_heads, head_dim, scale, device, dtype, - causal=(k_idx == q_idx) - ) - print("All graphs captured") - - # Pre-fill static tensors - for q_idx in range(num_chunks): - for k_idx in range(q_idx + 1): - g = graphs[q_idx][k_idx] - g.static_q.copy_(full_q[:, q_idx*chunk_size:(q_idx+1)*chunk_size]) - g.static_k.copy_(full_k[:, k_idx*chunk_size:(k_idx+1)*chunk_size]) - g.static_v.copy_(full_v[:, k_idx*chunk_size:(k_idx+1)*chunk_size]) - print("Static tensors pre-filled") - - # Replay and merge - chunked_output = torch.zeros_like(full_output) - for q_idx in range(num_chunks): - acc_out, acc_lse = None, None - for k_idx in range(q_idx + 1): - g = graphs[q_idx][k_idx] - g.graph.replay() - out, lse = g.static_output.clone(), g.static_lse.clone() - if acc_out is None: - acc_out, acc_lse = out, lse - else: - with torch.inference_mode(): - acc_out, acc_lse = merge_attention_outputs(acc_out, acc_lse, out, lse) - chunked_output[:, q_idx*chunk_size:(q_idx+1)*chunk_size] = acc_out - - torch.cuda.synchronize() - - # Compare - all_pass = True - for q_idx in range(num_chunks): - s, e = q_idx * chunk_size, (q_idx + 1) * chunk_size - diff = (full_output[:, s:e] - chunked_output[:, s:e]).abs().max().item() - status = "✅" if diff < 1e-2 else "❌" - print(f"Q[{q_idx}]: max_diff={diff:.2e} {status}") - if diff >= 1e-2: - all_pass = False - - print("✅ PASSED" if all_pass else "❌ FAILED") - - -if __name__ == "__main__": - main() diff --git a/tests/test_xattn_bsa.py b/tests/test_xattn_bsa.py deleted file mode 100644 index cd6529a..0000000 --- a/tests/test_xattn_bsa.py +++ /dev/null @@ -1,334 +0,0 @@ -""" -Test XAttention + BSA with RULER benchmark data. - -Tests XAttention sparse attention correctness using RULER NIAH task. - -Attention methods: - - Prefill: XAttention + BSA (sparse) or FlashAttention (dense) - - Decode: FlashAttention (always, since q_len=1) - -Usage (in compass conda env with BSA available): - CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH \ - python tests/test_xattn_bsa.py --model ~/models/Llama-3.1-8B-Instruct - - # Test with XAttention + BSA for prefill (default) - python tests/test_xattn_bsa.py --prefill-method xattn - - # Test with FlashAttention for prefill (baseline) - python tests/test_xattn_bsa.py --prefill-method flash - - # Test specific sample(s) - python tests/test_xattn_bsa.py --sample-id 0 - python tests/test_xattn_bsa.py --sample-ids 0,1,2 - -Note: Compatible with transformers 4.53+ (handles both old `past_key_value` - and new `past_key_values` API). -""" - -import argparse -import json -import sys -import torch -from pathlib import Path -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.cache_utils import DynamicCache - -from nanovllm.ops.xattn import xattn_estimate -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb - - -# ============================================================ -# XAttention + BSA Functions -# ============================================================ - -def expand_kv_for_gqa(key_states, value_states, num_heads): - """Expand KV for Grouped Query Attention.""" - num_kv_heads = key_states.shape[1] - if num_heads == num_kv_heads: - return key_states, value_states - num_groups = num_heads // num_kv_heads - return key_states.repeat_interleave(num_groups, dim=1), value_states.repeat_interleave(num_groups, dim=1) - - -def flash_attention_forward(query_states, key_states, value_states, is_causal=True): - """Standard FlashAttention.""" - from flash_attn import flash_attn_func - q = query_states.transpose(1, 2) - k = key_states.transpose(1, 2) - v = value_states.transpose(1, 2) - return flash_attn_func(q, k, v, causal=is_causal).transpose(1, 2) - - -def xattn_bsa_forward(query_states, key_states, value_states, threshold=0.9): - """XAttention + BSA sparse attention.""" - from block_sparse_attn import block_sparse_attn_func - - batch_size, num_heads, q_len, head_dim = query_states.shape - k_len = key_states.shape[2] - - _, mask = xattn_estimate( - query_states, key_states, - chunk_size=16384, block_size=128, threshold=threshold, - use_triton=True, causal=True, - ) - - q_block_num = (q_len + 127) // 128 - k_block_num = (k_len + 127) // 128 - - q = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim) - k = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim) - v = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim) - - __import__('pdb').set_trace() - - output = block_sparse_attn_func( - q, k, v, - torch.tensor([0, q_len], dtype=torch.int32, device=q.device), - torch.tensor([0, k_len], dtype=torch.int32, device=k.device), - torch.ones(num_heads, dtype=torch.int32, device=q.device), - None, - mask[:, :, :q_block_num, :k_block_num].contiguous(), - q_len, k_len, - p_dropout=0.0, deterministic=True, is_causal=True, - ) - return output.reshape(batch_size, q_len, num_heads, head_dim).transpose(1, 2) - - -DEBUG = False # Set to True to enable debugging - -def create_patched_forward(prefill_method="xattn", threshold=0.9): - """Create patched forward with configurable prefill method. - - Args: - prefill_method: "xattn" for XAttention + BSA (sparse), "flash" for FlashAttention (dense) - threshold: XAttention threshold for block selection (only used when prefill_method="xattn") - - Note: - - Prefill (q_len > 1): Uses specified prefill_method - - Decode (q_len = 1): Always uses FlashAttention (no sparse needed for single query) - """ - call_count = [0] # Mutable to track calls across layers - - def patched_forward( - self, - hidden_states, - position_embeddings=None, - attention_mask=None, - past_key_value=None, # Old API (transformers < 4.57) - past_key_values=None, # New API (transformers >= 4.57) - cache_position=None, - **kwargs - ): - # Handle both old and new transformers API - kv_cache = past_key_values if past_key_values is not None else past_key_value - - bsz, q_len, _ = hidden_states.size() - num_heads = self.config.num_attention_heads - num_kv_heads = self.config.num_key_value_heads - head_dim = self.head_dim - - # Compute Q, K, V projections - query_states = self.q_proj(hidden_states).view(bsz, q_len, num_heads, head_dim).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, q_len, num_kv_heads, head_dim).transpose(1, 2) - - # Apply rotary position embedding - cos, sin = position_embeddings - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - - # Handle KV cache - if kv_cache is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = kv_cache.update( - key_states, value_states, self.layer_idx, cache_kwargs - ) - - # Expand KV for GQA - key_states_exp, value_states_exp = expand_kv_for_gqa(key_states, value_states, num_heads) - - # Debug output - if DEBUG and self.layer_idx == 0: - call_count[0] += 1 - if call_count[0] <= 5: - phase = "prefill" if q_len > 1 else "decode" - print(f"\n[DEBUG] Layer {self.layer_idx}, call {call_count[0]} ({phase}): q_len={q_len}, k_len={key_states_exp.shape[2]}") - print(f" kv_cache is None: {kv_cache is None}") - - # Choose attention method: - # - Prefill (q_len > 1): Use prefill_method (xattn or flash) - # - Decode (q_len = 1): Always use FlashAttention - is_prefill = q_len > 1 - - if is_prefill and prefill_method == "xattn": - # Prefill with XAttention + BSA (sparse) - attn_output = xattn_bsa_forward(query_states, key_states_exp, value_states_exp, threshold) - else: - # Prefill with FlashAttention (dense) OR Decode (always FlashAttention) - # Note: For decode (q_len=1), causal=False since single query attends to all KV - attn_output = flash_attention_forward(query_states, key_states_exp, value_states_exp, is_causal=is_prefill) - - attn_output = self.o_proj(attn_output.transpose(1, 2).reshape(bsz, q_len, -1)) - return attn_output, None - - return patched_forward - - -# ============================================================ -# Data & Evaluation -# ============================================================ - -def load_samples(filepath, indices=None): - """Load samples from JSONL file.""" - samples = [] - with open(filepath) as f: - for i, line in enumerate(f): - if indices is None or i in indices: - sample = json.loads(line) - sample["_idx"] = i - samples.append(sample) - return samples - - -def string_match_all(output_text, expected_list): - """RULER metric: fraction of expected values found in output.""" - output_lower = output_text.lower().replace('\n', ' ') - if not expected_list: - return 1.0 - return sum(1.0 if exp.strip().lower() in output_lower else 0.0 for exp in expected_list) / len(expected_list) - - -# ============================================================ -# Test -# ============================================================ - -def test_with_ruler_data(model_path, data_file, sample_ids, prefill_method="xattn", threshold=0.9, max_new_tokens=50): - """Test attention methods using RULER data. - - Args: - prefill_method: "xattn" for XAttention + BSA, "flash" for FlashAttention - """ - prefill_desc = "XAttention + BSA (sparse)" if prefill_method == "xattn" else "FlashAttention (dense)" - - print("=" * 60) - print("RULER NIAH Attention Test") - print("=" * 60) - print(f"Data: {data_file}") - print(f"Samples: {sample_ids}") - print(f"Prefill method: {prefill_desc}") - print(f"Decode method: FlashAttention (always)") - if prefill_method == "xattn": - print(f"XAttention threshold: {threshold}") - - samples = load_samples(Path(data_file), set(sample_ids) if sample_ids else None) - if not samples: - print("No samples found!") - return False - print(f"Loaded {len(samples)} samples") - - # Load model - print(f"\nLoading model: {model_path}") - tokenizer = AutoTokenizer.from_pretrained(model_path) - model = AutoModelForCausalLM.from_pretrained( - model_path, torch_dtype=torch.float16, device_map="cuda", - attn_implementation="eager", # Will be patched - ) - model.eval() - - # Patch all layers - print(f"Patching attention layers...") - print(f" - Prefill: {prefill_desc}") - print(f" - Decode: FlashAttention") - for idx, layer in enumerate(model.model.layers): - layer.self_attn.layer_idx = idx # Ensure layer_idx is set - layer.self_attn.forward = create_patched_forward(prefill_method, threshold).__get__( - layer.self_attn, type(layer.self_attn) - ) - - total_score = 0.0 - results = [] - - for sample in samples: - idx = sample["_idx"] - prompt = sample["input"] - expected = sample["outputs"] - - inputs = tokenizer(prompt, return_tensors="pt").to("cuda") - num_tokens = inputs["input_ids"].shape[1] - print(f"\n--- Sample {idx} ({num_tokens} tokens) ---") - print(f"Expected: {expected}") - - with torch.no_grad(): - output = model.generate( - inputs["input_ids"], - max_new_tokens=max_new_tokens, - do_sample=False, - pad_token_id=tokenizer.eos_token_id, - ) - output_text = tokenizer.decode(output[0][num_tokens:], skip_special_tokens=True) - score = string_match_all(output_text, expected) - total_score += score - - status = "✓ PASS" if score >= 0.5 else "✗ FAIL" - print(f"Output: '{output_text[:100]}...'") - print(f"Result: {status} (score={score:.2f})") - results.append({"idx": idx, "score": score, "passed": score >= 0.5}) - - avg_score = total_score / len(samples) - passed = sum(1 for r in results if r["passed"]) - - print(f"\n{'='*60}") - print(f"Results: {passed}/{len(samples)} passed, avg_score={avg_score:.3f}") - print(f"{'='*60}") - - return avg_score >= 0.5 - - -def main(): - parser = argparse.ArgumentParser( - description="Test XAttention + BSA vs FlashAttention for prefill using RULER NIAH benchmark" - ) - parser.add_argument("--model", default="~/models/Llama-3.1-8B-Instruct") - parser.add_argument("--data-file", default="tests/data/ruler_32k/niah_single_1/validation.jsonl") - parser.add_argument("--sample-id", type=int, default=None, help="Test single sample by index") - parser.add_argument("--sample-ids", type=str, default="", help="Test multiple samples (comma-separated)") - parser.add_argument("--prefill-method", choices=["xattn", "flash"], default="xattn", - help="Prefill attention method: xattn (XAttention+BSA sparse) or flash (FlashAttention dense)") - parser.add_argument("--threshold", type=float, default=0.9, help="XAttention threshold (only for --prefill-method xattn)") - parser.add_argument("--max-new-tokens", type=int, default=50) - # Keep old option for backwards compatibility - parser.add_argument("--no-xattn", action="store_true", help="[Deprecated] Use --prefill-method flash instead") - args = parser.parse_args() - - model_path = args.model.replace("~", "/home/zijie") - - # Handle deprecated --no-xattn option - prefill_method = args.prefill_method - if args.no_xattn: - prefill_method = "flash" - print("Warning: --no-xattn is deprecated, use --prefill-method flash instead") - - if args.sample_id is not None: - sample_ids = [args.sample_id] - elif args.sample_ids: - sample_ids = [int(x) for x in args.sample_ids.split(",")] - else: - sample_ids = [0] - - # Check BSA availability if using xattn - if prefill_method == "xattn": - try: - from block_sparse_attn import block_sparse_attn_func - print("✓ BSA (Block Sparse Attention) available") - except ImportError: - print("✗ BSA not found. Install block_sparse_attn or use --prefill-method flash") - sys.exit(1) - - if test_with_ruler_data(model_path, args.data_file, sample_ids, prefill_method, args.threshold, args.max_new_tokens): - print("\ntest_xattn_bsa: PASSED") - else: - print("\ntest_xattn_bsa: FAILED") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/tests/test_xattn_chunked.py b/tests/test_xattn_chunked.py deleted file mode 100644 index d6fc4c6..0000000 --- a/tests/test_xattn_chunked.py +++ /dev/null @@ -1,259 +0,0 @@ -""" -Test: Compare xattn_estimate vs xattn_estimate_chunked -Verify that chunked estimation with EXTERNAL chunking produces the same mask as standard estimation. - -Uses real QKV data captured from model inference. -""" - -import sys -import os -import torch -import warnings - -from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked - -# ============================================================ -# Configuration -# ============================================================ - -BLOCK_SIZE = 64 -STRIDE = 4 -THRESHOLD = 0.9 -CHUNK_SIZE = 4096 - -# Default QKV data directory (relative to project root) -DEFAULT_QKV_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "results", "kvcache") - -# ============================================================ -# Utility Functions -# ============================================================ - -def load_qkv(path): - """Load saved QKV data.""" - data = torch.load(path, map_location="cpu", weights_only=False) - print(f"Loaded: {path}") - print(f" Query shape: {data['query'].shape}") - print(f" Key shape: {data['key'].shape}") - print(f" Layer: {data['layer_id']}, Density: {data['density']:.2%}") - return data - - -def compare_masks(mask1, mask2, name1="standard", name2="chunked"): - """Compare two masks and report differences.""" - if mask1.shape != mask2.shape: - print(f"Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}") - return False - - diff = (mask1 != mask2).sum().item() - total = mask1.numel() - match_rate = (total - diff) / total * 100 - - print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})") - - if diff > 0: - diff_indices = torch.where(mask1 != mask2) - print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}") - - return diff == 0 - - -def run_chunked_externally(query, key, q_start_pos, block_size, stride, threshold, chunk_size): - """ - Run xattn_estimate_chunked with EXTERNAL chunking. - This simulates how chunked prefill should be used in practice. - """ - batch_size, num_heads, q_len, head_dim = query.shape - _, _, k_len, _ = key.shape - - q_block_num = (q_len + block_size - 1) // block_size - k_block_num = (k_len + block_size - 1) // block_size - - # If Q fits in one chunk, call directly - if q_len <= chunk_size: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - return xattn_estimate_chunked( - query, key, - q_start_pos=q_start_pos, - block_size=block_size, - stride=stride, - threshold=threshold, - use_triton=True, - chunk_size=chunk_size, - ) - - # External chunking: split Q and call for each chunk - num_q_chunks = (q_len + chunk_size - 1) // chunk_size - print(f" External chunking: {num_q_chunks} chunks") - - combined_attn_sum = torch.zeros( - batch_size, num_heads, q_block_num, k_block_num, - dtype=query.dtype, device=query.device - ) - combined_mask = torch.zeros( - batch_size, num_heads, q_block_num, k_block_num, - dtype=torch.bool, device=query.device - ) - - q_block_offset = 0 - for q_chunk_idx in range(num_q_chunks): - q_chunk_start = q_chunk_idx * chunk_size - q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len) - - q_chunk = query[:, :, q_chunk_start:q_chunk_end, :] - - # For causal attention, K accumulates up to current Q position - k_end = q_start_pos + q_chunk_end - k_chunk = key[:, :, :k_end, :] - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - attn_sum_chunk, mask_chunk = xattn_estimate_chunked( - q_chunk, k_chunk, - q_start_pos=q_start_pos + q_chunk_start, - block_size=block_size, - stride=stride, - threshold=threshold, - use_triton=True, - chunk_size=chunk_size, - ) - - # Place chunk results into combined output - chunk_q_blocks = mask_chunk.shape[2] - chunk_k_blocks = mask_chunk.shape[3] - combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk - combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk - q_block_offset += chunk_q_blocks - - return combined_attn_sum, combined_mask - - -def test_single_qkv(qkv_path): - """Test a single QKV file.""" - data = load_qkv(qkv_path) - query = data["query"].cuda().to(torch.bfloat16) - key = data["key"].cuda().to(torch.bfloat16) - - seq_len = query.shape[2] - print(f"\nTesting with seq_len={seq_len}") - print("=" * 60) - - # Run standard xattn_estimate - print("[1] Running standard xattn_estimate...") - try: - attn_sum_std, mask_std = xattn_estimate( - query, key, - block_size=BLOCK_SIZE, - stride=STRIDE, - threshold=THRESHOLD, - chunk_size=CHUNK_SIZE, - use_triton=True, - ) - print(f" mask shape: {mask_std.shape}, density: {mask_std.float().mean().item():.4f}") - except Exception as e: - print(f" ERROR: {e}") - import traceback - traceback.print_exc() - return False - - # Run chunked xattn_estimate with EXTERNAL chunking - print("[2] Running chunked xattn_estimate (external chunking)...") - try: - attn_sum_chunked, mask_chunked = run_chunked_externally( - query, key, - q_start_pos=0, - block_size=BLOCK_SIZE, - stride=STRIDE, - threshold=THRESHOLD, - chunk_size=CHUNK_SIZE, - ) - print(f" mask shape: {mask_chunked.shape}, density: {mask_chunked.float().mean().item():.4f}") - except Exception as e: - print(f" ERROR: {e}") - import traceback - traceback.print_exc() - return False - - # Compare results - print("[3] Comparing results...") - chunked_q_blocks = mask_chunked.shape[2] - chunked_k_blocks = mask_chunked.shape[3] - - # Extract comparable region from standard mask - mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks] - - # Compare masks - masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked") - - # Compare attn_sums - attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks] - if attn_sum_std_comparable.shape == attn_sum_chunked.shape: - attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item() - print(f" Attn sum max diff: {attn_diff:.6f}") - else: - print(f" Attn sum shape mismatch") - - # Clean up GPU memory - del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked - torch.cuda.empty_cache() - - return masks_match - - -# ============================================================ -# Main Test -# ============================================================ - -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser(description="Test xattn_estimate vs xattn_estimate_chunked") - parser.add_argument("--qkv-dir", type=str, default=DEFAULT_QKV_DIR, - help="Directory containing QKV files") - args = parser.parse_args() - - # QKV files to test - qkv_files = [ - os.path.join(args.qkv_dir, "qkv_3688.pt"), # ~4K - os.path.join(args.qkv_dir, "qkv_7888.pt"), # ~8K - os.path.join(args.qkv_dir, "qkv_15685.pt"), # ~16K - os.path.join(args.qkv_dir, "qkv_32485.pt"), # ~32K - os.path.join(args.qkv_dir, "qkv_64891.pt"), # ~64K - ] - - available_files = [p for p in qkv_files if os.path.exists(p)] - - if not available_files: - print(f"No QKV file found in {args.qkv_dir}.") - print(f"Expected files: qkv_3688.pt, qkv_7888.pt, qkv_15685.pt, qkv_32485.pt, qkv_64891.pt") - sys.exit(1) - - print(f"Found {len(available_files)} QKV files to test") - print(f"Testing EXTERNAL chunking (chunk_size={CHUNK_SIZE})") - print(f"Using Triton kernels") - - all_passed = True - results = [] - - for qkv_path in available_files: - passed = test_single_qkv(qkv_path) - seq_len = int(os.path.basename(qkv_path).replace("qkv_", "").replace(".pt", "")) - results.append((seq_len, passed)) - if not passed: - all_passed = False - - # Summary - print("\n" + "=" * 60) - print("SUMMARY") - print("=" * 60) - for seq_len, passed in results: - status = "PASSED" if passed else "FAILED" - chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE - print(f" seq_len={seq_len} ({chunks} chunk{'s' if chunks > 1 else ''}): {status}") - - print("=" * 60) - if all_passed: - print("test_xattn_chunked: PASSED") - sys.exit(0) - else: - print("test_xattn_chunked: FAILED") - sys.exit(1) diff --git a/tests/test_xattn_estimate_chunked.py b/tests/test_xattn_estimate_chunked.py deleted file mode 100644 index 76cb664..0000000 --- a/tests/test_xattn_estimate_chunked.py +++ /dev/null @@ -1,244 +0,0 @@ -""" -Test: Compare xattn_estimate vs xattn_estimate_chunked - -Verify that chunked estimation with EXTERNAL chunking produces the same mask -as standard estimation. This ensures the chunked version can be used in -chunked prefill scenarios without accuracy loss. - -Usage: - CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \ - python tests/test_xattn_estimate_chunked.py -""" - -import sys -import traceback -import torch -from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked - -# ============================================================ -# Configuration -# ============================================================ - -# Configuration for xattn_estimate_chunked consistency test. -# Key requirements for 100% match: -# 1. Use matching chunk_size for both standard and chunked versions -# 2. Use same random seed for reproducibility -# Note: Tiny differences (~0.000001) may occur at boundary cases due to -# floating point precision in cumulative sum calculations. -BLOCK_SIZE = 64 -STRIDE = 4 -THRESHOLD = 0.9 -CHUNK_SIZE = 4096 # External chunking size - -# Test sequence lengths -TEST_SEQ_LENS = [4096, 8192, 16384, 32768] - -# ============================================================ -# Utility Functions -# ============================================================ - -def compare_masks(mask1, mask2, name1="standard", name2="chunked"): - """Compare two masks and report differences.""" - if mask1.shape != mask2.shape: - print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}") - return False - - diff = (mask1 != mask2).sum().item() - total = mask1.numel() - match_rate = (total - diff) / total * 100 - - print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})") - - if diff > 0: - diff_indices = torch.where(mask1 != mask2) - print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}") - - return diff == 0 - - -def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size): - """ - Run xattn_estimate_chunked with EXTERNAL chunking. - This simulates how chunked prefill should be used in practice. - """ - batch_size, num_heads, q_len, head_dim = query.shape - _, _, k_len, _ = key.shape - - q_block_num = (q_len + block_size - 1) // block_size - k_block_num = (k_len + block_size - 1) // block_size - - # If Q fits in one chunk, call directly - if q_len <= chunk_size: - return xattn_estimate_chunked( - query, key, - q_start_pos=0, - block_size=block_size, - stride=stride, - threshold=threshold, - use_triton=True, - chunk_size=chunk_size, - ) - - # External chunking: split Q and call for each chunk - num_q_chunks = (q_len + chunk_size - 1) // chunk_size - print(f" External chunking: {num_q_chunks} chunks") - - combined_attn_sum = torch.zeros( - batch_size, num_heads, q_block_num, k_block_num, - dtype=query.dtype, device=query.device - ) - combined_mask = torch.zeros( - batch_size, num_heads, q_block_num, k_block_num, - dtype=torch.bool, device=query.device - ) - - q_block_offset = 0 - for q_chunk_idx in range(num_q_chunks): - q_chunk_start = q_chunk_idx * chunk_size - q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len) - - q_chunk = query[:, :, q_chunk_start:q_chunk_end, :] - - # For causal attention, K accumulates up to current Q position - # q_start_pos=0 means Q starts at position 0 in the full sequence - # K is [0, q_chunk_end) for causal attention - k_end = q_chunk_end - k_chunk = key[:, :, :k_end, :] - - attn_sum_chunk, mask_chunk = xattn_estimate_chunked( - q_chunk, k_chunk, - q_start_pos=q_chunk_start, - block_size=block_size, - stride=stride, - threshold=threshold, - use_triton=True, - chunk_size=chunk_size, - ) - - # Place chunk results into combined output - chunk_q_blocks = mask_chunk.shape[2] - chunk_k_blocks = mask_chunk.shape[3] - combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk - combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk - q_block_offset += chunk_q_blocks - - return combined_attn_sum, combined_mask - - -def test_single_seq_len(seq_len, num_heads=32, head_dim=128): - """Test a single sequence length.""" - print(f"\nTesting seq_len={seq_len}") - print("=" * 60) - - # Generate random Q/K - query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16) - key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16) - - # Run standard xattn_estimate - print("[1] Running standard xattn_estimate...") - try: - attn_sum_std, mask_std = xattn_estimate( - query, key, - block_size=BLOCK_SIZE, - stride=STRIDE, - threshold=THRESHOLD, - chunk_size=CHUNK_SIZE, - use_triton=True, - causal=True, - ) - density_std = mask_std.float().mean().item() - print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}") - except Exception as e: - print(f" ERROR: {e}") - traceback.print_exc() - return False - - # Run chunked xattn_estimate with EXTERNAL chunking - print("[2] Running chunked xattn_estimate (external chunking)...") - try: - attn_sum_chunked, mask_chunked = run_chunked_externally( - query, key, - block_size=BLOCK_SIZE, - stride=STRIDE, - threshold=THRESHOLD, - chunk_size=CHUNK_SIZE, - ) - density_chunked = mask_chunked.float().mean().item() - print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}") - except Exception as e: - print(f" ERROR: {e}") - traceback.print_exc() - return False - - # Compare results - print("[3] Comparing results...") - chunked_q_blocks = mask_chunked.shape[2] - chunked_k_blocks = mask_chunked.shape[3] - - # Extract comparable region from standard mask - mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks] - - # Compare masks - masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked") - - # Compare attn_sums - attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks] - if attn_sum_std_comparable.shape == attn_sum_chunked.shape: - attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item() - print(f" Attn sum max diff: {attn_diff:.6f}") - else: - print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}") - - # Clean up GPU memory - del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked - torch.cuda.empty_cache() - - return masks_match - - -# ============================================================ -# Main Test -# ============================================================ - -if __name__ == "__main__": - print("XAttention Chunked vs Standard Test") - print("=" * 60) - print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}") - print(f"External chunk_size={CHUNK_SIZE}") - print() - - # Check CUDA availability - if not torch.cuda.is_available(): - print("CUDA not available!") - sys.exit(1) - - print(f"Using GPU: {torch.cuda.get_device_name(0)}") - print("✓ xattn_estimate imported") - print("✓ xattn_estimate_chunked imported") - - # Run tests - all_passed = True - results = [] - - for seq_len in TEST_SEQ_LENS: - passed = test_single_seq_len(seq_len) - chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE - results.append((seq_len, chunks, passed)) - if not passed: - all_passed = False - - # Summary - print("\n" + "=" * 60) - print("SUMMARY") - print("=" * 60) - for seq_len, chunks, passed in results: - status = "PASSED" if passed else "FAILED" - print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}") - - print("=" * 60) - if all_passed: - print("ALL TESTS PASSED!") - sys.exit(0) - else: - print("SOME TESTS FAILED!") - sys.exit(1) diff --git a/tests/test_xattn_kernels.py b/tests/test_xattn_kernels.py deleted file mode 100644 index 8e5fcfb..0000000 --- a/tests/test_xattn_kernels.py +++ /dev/null @@ -1,132 +0,0 @@ -""" -Test: XAttention Triton kernels - -演示 XAttention 的两个核心 Triton kernel: -1. flat_group_gemm_fuse_reshape: 计算 stride reshape 后的 attention scores (反对角线求和) -2. softmax_fuse_block_sum: 对 attention scores 做 softmax 后按 block 求和 - -数据流: - Q [batch, heads, q_len, head_dim] - K [batch, heads, kv_len, head_dim] - ↓ flat_group_gemm_fuse_reshape - attn_scores [batch, heads, q_len/stride, kv_len/stride] - ↓ softmax_fuse_block_sum - block_sums [batch, heads, q_blocks, k_blocks] -""" -import torch -import sys -sys.path.insert(0, "/home/zijie/Code/nano-vllm") -from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum - -# ============================================================ -# 参数配置 -# ============================================================ - -# Triton 约束: q_len >= stride * BLOCK_M, kv_len >= stride * BLOCK_N -# A100: BLOCK_M = BLOCK_N = 128, 所以 min = 4 * 128 = 512 -# RTX 3090: BLOCK_M = BLOCK_N = 64, 所以 min = 4 * 64 = 256 -q_len = 512 -kv_len = 2048 -head_dim = 128 -stride = 4 -block_size = 128 # softmax block size (in reshaped space) -segment_size = 128 # Triton kernel 要求 segment_size >= block_size - -# ============================================================ -# 构造输入: 偶数位置=1, 奇数位置=2 -# ============================================================ - -Q = torch.zeros(1, 1, q_len, head_dim, dtype=torch.bfloat16).cuda() -K = torch.zeros(1, 1, kv_len, head_dim, dtype=torch.bfloat16).cuda() - -for i in range(q_len): - if i % 2 == 0: - Q[0, 0, i, :] = 1 * (i // stride + 1) - else: - Q[0, 0, i, :] = 2 * (i // stride + 1) - -for i in range(kv_len): - if i % 2 == 0: - K[0, 0, i, :] = 1 - else: - K[0, 0, i, :] = 2 - -# ============================================================ -# Step 1: flat_group_gemm_fuse_reshape (chunked along K) -# ============================================================ - -q_reshaped_len = q_len // stride # 128 -kv_reshaped_len = kv_len // stride # 512 - -# 将 K 沿着长度维度分成多个 chunk -k_chunk_size = 512 # 每个 chunk 512 tokens -num_k_chunks = kv_len // k_chunk_size # 4 chunks - -attn_scores_list = [] -for k_chunk_idx in range(num_k_chunks): - k_start = k_chunk_idx * k_chunk_size - k_end = k_start + k_chunk_size - K_chunk = K[:, :, k_start:k_end, :] # [1, 1, k_chunk_size, head_dim] - - # 对每个 K chunk 调用 flat_group_gemm_fuse_reshape - # 输出: [batch, heads, q_len/stride, k_chunk_size/stride] - attn_chunk = flat_group_gemm_fuse_reshape( - Q, K_chunk, stride, - chunk_start=0, - chunk_end=q_reshaped_len, - is_causal=True - ) - - __import__('pdb').set_trace() - - attn_scores_list.append(attn_chunk) - -# 拼接所有 K chunks 的结果 -# 每个 chunk: [1, 1, q_reshaped_len, k_chunk_size/stride] -# 拼接后: [1, 1, q_reshaped_len, kv_reshaped_len] -attn_scores = torch.cat(attn_scores_list, dim=-1) - -# 验证 shape: [batch, heads, q_len/stride, kv_len/stride] -assert attn_scores.shape == (1, 1, q_reshaped_len, kv_reshaped_len), \ - f"shape mismatch: {attn_scores.shape} != (1, 1, {q_reshaped_len}, {kv_reshaped_len})" - -# 验证: 反对角线求和 -# 每个 stride x stride 块的反对角线: Q[奇]*K[偶] + Q[偶]*K[奇] = 2*1 + 1*2 = 4 -# 反对角线有 stride/2 对,再乘以 head_dim -expected_gemm = (2*1 + 1*2) * (stride // 2) * head_dim -actual_gemm = attn_scores[0, 0, 0, 0].item() -assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expected_gemm}" - -# ============================================================ -# Step 2: softmax_fuse_block_sum -# ============================================================ - -scale = 1.4426950408889634 # log2(e) for exp2 - -block_sums = softmax_fuse_block_sum( - attn_scores, - block_size, - segment_size, - chunk_start=0, - chunk_end=q_reshaped_len, - real_q_len=q_reshaped_len, - scale=scale, - is_causal=False -) - -# 验证 shape: [batch, heads, q_blocks, k_blocks] -q_blocks = q_reshaped_len // block_size # 128 / 128 = 1 -k_blocks = kv_reshaped_len // block_size # 512 / 128 = 4 -assert block_sums.shape == (1, 1, q_blocks, k_blocks), \ - f"shape mismatch: {block_sums.shape} != (1, 1, {q_blocks}, {k_blocks})" - -# 验证: 每个 block 的 softmax 结果求和 -# 所有 attn_scores 相同 → softmax 均匀分布 -# 每行对一个 K block 的贡献 = block_size / kv_reshaped_len -# 每个 Q block 有 block_size 行 -# block_sum = block_size * (block_size / kv_reshaped_len) -expected_sum = block_size * block_size / kv_reshaped_len -actual_sum = block_sums[0, 0, 0, 0].item() -assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}" - -print("test_xattn_kernels: PASSED") diff --git a/tests/test_xattn_kv_chunking_batch.py b/tests/test_xattn_kv_chunking_batch.py deleted file mode 100644 index 60c8288..0000000 --- a/tests/test_xattn_kv_chunking_batch.py +++ /dev/null @@ -1,246 +0,0 @@ -""" -Test: 批量验证 xattn_estimate 与 KV chunking kernels 的一致性 - -测试 results/kvcache 下所有保存的 QKV 数据 - -Usage: - CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \ - python tests/test_xattn_kv_chunking_batch.py -""" -import sys -sys.path.insert(0, "/home/zijie/Code/nano-vllm") - -import os -import glob -import torch -import math -from nanovllm.ops.xattn import ( - xattn_estimate, - flat_group_gemm_fuse_reshape, - softmax_compute_partial_stats, - softmax_normalize_and_block_sum, - merge_softmax_stats, - find_blocks_chunked, -) - -# ============================================================ -# 参数配置 -# ============================================================ -DATA_DIR = "/home/zijie/Code/nano-vllm/results/kvcache" -BSA_BLOCK_SIZE = 128 -CHUNK_SIZE = 16384 - -device = "cuda" - - -def test_single_file(data_file: str) -> dict: - """测试单个 kvcache 文件""" - data = torch.load(data_file, map_location="cpu") - Q = data["query"].to(device) - K = data["key"].to(device) - - batch_size, num_heads, seq_len, head_dim = Q.shape - STRIDE = data["stride"] - THRESHOLD = data["threshold"][0].item() if isinstance(data["threshold"], torch.Tensor) else data["threshold"] - - # ========== xattn_estimate API ========== - attn_sums_api, mask_api = xattn_estimate( - Q, K, - block_size=BSA_BLOCK_SIZE, - stride=STRIDE, - threshold=THRESHOLD, - chunk_size=CHUNK_SIZE, - causal=True, - ) - - q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE - k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE - mask_api_valid = mask_api[:, :, :q_blocks, :k_blocks] - - causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, device=device, dtype=torch.bool)) - total_api = causal_mask.sum().item() * batch_size * num_heads - selected_api = (mask_api_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() - density_api = selected_api / total_api - - # ========== 三阶段 KV Chunking ========== - k_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len - q_num_to_pad = ((seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE) * CHUNK_SIZE - seq_len - q_chunk_num = (seq_len + q_num_to_pad) // CHUNK_SIZE - kv_chunk_num = (seq_len + k_num_to_pad) // CHUNK_SIZE - - k_block_num = (seq_len + k_num_to_pad) // BSA_BLOCK_SIZE - q_block_num = (seq_len + q_num_to_pad) // BSA_BLOCK_SIZE - - reshaped_chunk_size = CHUNK_SIZE // STRIDE - reshaped_block_size = BSA_BLOCK_SIZE // STRIDE - k_reshaped_seq_len = (seq_len + k_num_to_pad) // STRIDE - k_reshaped_num_to_pad = k_num_to_pad // STRIDE - num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size - kv_reshaped_chunk_size = CHUNK_SIZE // STRIDE - - if k_num_to_pad > 0: - K_padded = torch.nn.functional.pad(K, (0, 0, 0, k_num_to_pad), value=0) - else: - K_padded = K - - if q_num_to_pad > 0: - Q_padded = torch.nn.functional.pad(Q, (0, 0, 0, q_num_to_pad), value=0) - else: - Q_padded = Q - - norm = 1.0 - scale = 1.4426950408889634 / math.sqrt(head_dim) / STRIDE / norm - - simple_mask_list = [] - - for q_chunk_idx in range(q_chunk_num): - q_start = q_chunk_idx * reshaped_chunk_size * STRIDE - q_end = q_start + reshaped_chunk_size * STRIDE - Q_chunk = Q_padded[:, :, q_start:q_end, :] - - chunk_start = (k_block_num - q_block_num) * reshaped_block_size + q_chunk_idx * reshaped_chunk_size - chunk_end = chunk_start + reshaped_chunk_size - - m_chunks = [] - l_chunks = [] - attn_weights_chunks = [] - - for kv_chunk_idx in range(kv_chunk_num): - kv_start = kv_chunk_idx * CHUNK_SIZE - kv_end = kv_start + CHUNK_SIZE - K_chunk = K_padded[:, :, kv_start:kv_end, :] - kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size - - attn_weights_kv = flat_group_gemm_fuse_reshape( - Q_chunk, K_chunk, STRIDE, - chunk_start=chunk_start, - chunk_end=chunk_end, - is_causal=False, - ) - attn_weights_chunks.append(attn_weights_kv) - - m_partial, l_partial = softmax_compute_partial_stats( - attn_weights_kv, - reshaped_block_size, - min(4096, reshaped_block_size), - scale, - chunk_start=chunk_start, - kv_offset=kv_offset_reshaped, - is_causal=True, - ) - m_chunks.append(m_partial) - l_chunks.append(l_partial) - - m_global, l_global = merge_softmax_stats(m_chunks, l_chunks) - - attn_sum_per_kv = [] - for kv_chunk_idx, attn_weights_kv in enumerate(attn_weights_chunks): - kv_offset_reshaped = kv_chunk_idx * kv_reshaped_chunk_size - attn_sum_kv = softmax_normalize_and_block_sum( - attn_weights_kv, - m_global, - l_global, - reshaped_block_size, - min(4096, reshaped_block_size), - chunk_start=chunk_start, - real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad, - scale=scale, - kv_offset=kv_offset_reshaped, - is_causal=True, - ) - attn_sum_per_kv.append(attn_sum_kv) - - attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1) - - simple_mask = find_blocks_chunked( - attn_sum_concat, - current_index=k_block_num - q_block_num + q_chunk_idx * num_blocks_per_chunk, - threshold=THRESHOLD, - num_to_choose=None, - decoding=False, - mode="prefill", - causal=True, - ) - simple_mask_list.append(simple_mask) - - mask_kv_chunking = torch.cat(simple_mask_list, dim=2) - - # 应用与 xattn_estimate 相同的 causal mask 后处理 (xattn.py 第 1300-1306 行) - mask_kv_chunking[:, :, -q_block_num:, -q_block_num:] = torch.where( - torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=device), diagonal=0), - mask_kv_chunking[:, :, -q_block_num:, -q_block_num:], - False, - ) - - mask_kv_chunking_valid = mask_kv_chunking[:, :, :q_blocks, :k_blocks] - selected_kv = (mask_kv_chunking_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() - density_kv = selected_kv / total_api - - mask_total = mask_api_valid.numel() - mask_diff = (mask_api_valid != mask_kv_chunking_valid).sum().item() - mask_diff_pct = 100 * mask_diff / mask_total - - return { - "seq_len": seq_len, - "stride": STRIDE, - "threshold": THRESHOLD, - "kv_chunks": kv_chunk_num, - "density_api": density_api, - "density_kv": density_kv, - "density_diff": abs(density_api - density_kv), - "mask_diff_pct": mask_diff_pct, - "passed": abs(density_api - density_kv) < 1e-6 and mask_diff_pct < 0.01, - } - - -def main(): - files = sorted(glob.glob(os.path.join(DATA_DIR, "qkv_*.pt"))) - - print("=" * 80) - print("XAttention KV Chunking Alignment Test") - print("=" * 80) - print() - - results = [] - for f in files: - fname = os.path.basename(f) - print(f"Testing {fname}...", end=" ", flush=True) - try: - r = test_single_file(f) - results.append(r) - status = "✓ PASS" if r["passed"] else "✗ FAIL" - print(f"{status} (seq_len={r['seq_len']}, kv_chunks={r['kv_chunks']})") - except Exception as e: - print(f"✗ ERROR: {e}") - results.append({"file": fname, "error": str(e)}) - - print() - print("=" * 80) - print("Results Summary") - print("=" * 80) - print() - print("| seq_len | stride | threshold | kv_chunks | density_api | density_kv | diff | mask_diff | status |") - print("|---------|--------|-----------|-----------|-------------|------------|------|-----------|--------|") - - all_passed = True - for r in results: - if "error" in r: - print(f"| ERROR | - | - | - | - | - | - | - | {r['error'][:20]} |") - all_passed = False - else: - status = "PASS" if r["passed"] else "FAIL" - if not r["passed"]: - all_passed = False - print(f"| {r['seq_len']:>7} | {r['stride']:>6} | {r['threshold']:.2f} | {r['kv_chunks']:>9} | " - f"{r['density_api']:.6f} | {r['density_kv']:.6f} | {r['density_diff']:.6f} | " - f"{r['mask_diff_pct']:.4f}% | {status} |") - - print() - if all_passed: - print("test_xattn_kv_chunking_batch: ALL PASSED") - else: - print("test_xattn_kv_chunking_batch: SOME FAILED") - - -if __name__ == "__main__": - main()