From aeed6ccdfb0ba4dfa22f87289511b76f4b0d69c5 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Mon, 2 Feb 2026 11:14:46 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=85=20test:=20add=20GPU-only=20density=20?= =?UTF-8?q?alignment=20verification=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add test to verify XAttention density calculation in GPU-only mode matches independent xattn_estimate calls. Changes: - Add tests/test_gpuonly_density_alignment.py: loads saved Q/K from xattn_bsa.py, calls xattn_estimate independently, compares results - Enhance debug save in xattn_bsa.py: now saves Q, K tensors and xattn_estimate parameters for external verification - Set _DEBUG_SAVE_MASK = False by default Usage: 1. Set _DEBUG_SAVE_MASK = True in xattn_bsa.py 2. Run GPU-only inference with XAttention (e.g., test_ruler.py) 3. Run tests/test_gpuonly_density_alignment.py to verify alignment Verified on 4k/8k/16k/32k/64k contexts - all pass with exact match. Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering) Co-Authored-By: Claude Co-Authored-By: Happy --- nanovllm/kvcache/sparse/xattn_bsa.py | 16 ++- tests/test_gpuonly_density_alignment.py | 149 ++++++++++++++++++++++++ 2 files changed, 162 insertions(+), 3 deletions(-) create mode 100644 tests/test_gpuonly_density_alignment.py diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index b5fe93d..a0c098b 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) # Global storage for mask debugging -_DEBUG_SAVE_MASK = True # Set to True to save masks for comparison +_DEBUG_SAVE_MASK = False # Set to True to save masks for comparison _DEBUG_MASK_STORAGE = {} # Check BSA availability @@ -399,7 +399,7 @@ class XAttentionBSAPolicy(SparsePolicy): causal=True, ) - # Debug: Save mask and attention sums for comparison + # Debug: Save Q, K, mask, attn_sums for external verification if _DEBUG_SAVE_MASK and layer_id == 0: import os valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE @@ -410,14 +410,24 @@ class XAttentionBSAPolicy(SparsePolicy): os.makedirs(save_dir, exist_ok=True) save_path = f"{save_dir}/gpuonly_layer{layer_id}.pt" torch.save({ + # Input tensors (GQA-expanded) + "Q": Q.clone().cpu(), # [1, num_heads, q_len, head_dim] + "K": K_exp.clone().cpu(), # [1, num_heads, k_len, head_dim] + # xattn_estimate parameters + "chunk_size": self.chunk_size, + "block_size": self.BSA_BLOCK_SIZE, + "stride": self.stride, + "threshold": self.threshold, + # Output for comparison "mask": mask_valid.clone().cpu(), "attn_sums": attn_sums_valid.clone().cpu(), + # Metadata "q_len": q_len, "k_len": k_len, "valid_q_blocks": valid_q_blocks, "valid_k_blocks": valid_k_blocks, }, save_path) - logger.info(f"[DEBUG] Saved mask to {save_path}, shape={mask_valid.shape}") + logger.info(f"[DEBUG] Saved Q/K/mask to {save_path}, Q={Q.shape}, K={K_exp.shape}") # Compute block counts q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE diff --git a/tests/test_gpuonly_density_alignment.py b/tests/test_gpuonly_density_alignment.py new file mode 100644 index 0000000..d0201fe --- /dev/null +++ b/tests/test_gpuonly_density_alignment.py @@ -0,0 +1,149 @@ +""" +Test: GPU-only density alignment verification + +验证 xattn_bsa.py 中 GPU-only 路径的 density 计算是否与独立调用 xattn_estimate 一致。 + +流程: +1. 运行 GPU-only 推理,保存 Q, K, mask, attn_sums +2. 加载保存的数据,独立调用 xattn_estimate +3. 比较两者的 mask 和 density +""" +import torch +import sys +sys.path.insert(0, "/home/zijie/Code/nano-vllm") + +from nanovllm.ops.xattn import xattn_estimate + +# ============================================================ +# 参数配置 +# ============================================================ + +DATA_PATH = "/home/zijie/Code/nano-vllm/results/mask_alignment/gpuonly_layer0.pt" + +# ============================================================ +# 加载保存的数据 +# ============================================================ + +print(f"Loading data from {DATA_PATH}") +data = torch.load(DATA_PATH, weights_only=False) + +Q = data["Q"].cuda() # [1, num_heads, q_len, head_dim] +K = data["K"].cuda() # [1, num_heads, k_len, head_dim] +chunk_size = data["chunk_size"] +block_size = data["block_size"] +stride = data["stride"] +threshold = data["threshold"] +mask_saved = data["mask"] # [1, num_heads, valid_q_blocks, valid_k_blocks] +attn_sums_saved = data["attn_sums"] +q_len = data["q_len"] +k_len = data["k_len"] +valid_q_blocks = data["valid_q_blocks"] +valid_k_blocks = data["valid_k_blocks"] + +print(f"Q shape: {Q.shape}") +print(f"K shape: {K.shape}") +print(f"q_len: {q_len}, k_len: {k_len}") +print(f"chunk_size: {chunk_size}, block_size: {block_size}, stride: {stride}, threshold: {threshold}") +print(f"valid_q_blocks: {valid_q_blocks}, valid_k_blocks: {valid_k_blocks}") +print(f"mask_saved shape: {mask_saved.shape}") + +# ============================================================ +# 独立调用 xattn_estimate +# ============================================================ + +print("\nCalling xattn_estimate independently...") +attn_sums_ext, mask_ext = xattn_estimate( + Q, K, + chunk_size=chunk_size, + block_size=block_size, + stride=stride, + threshold=threshold, + use_triton=True, + causal=True, +) + +# Trim to valid blocks +mask_ext_valid = mask_ext[:, :, :valid_q_blocks, :valid_k_blocks] +attn_sums_ext_valid = attn_sums_ext[:, :, :valid_q_blocks, :valid_k_blocks] + +print(f"mask_ext shape: {mask_ext.shape}") +print(f"mask_ext_valid shape: {mask_ext_valid.shape}") + +# ============================================================ +# 比较 attn_sums +# ============================================================ + +print("\n" + "=" * 60) +print("Comparing attn_sums") +print("=" * 60) + +attn_sums_saved_gpu = attn_sums_saved.cuda() +attn_diff = (attn_sums_ext_valid - attn_sums_saved_gpu).abs() +print(f"attn_sums max diff: {attn_diff.max().item():.6e}") +print(f"attn_sums mean diff: {attn_diff.mean().item():.6e}") + +# Check if attn_sums match +attn_match = attn_diff.max().item() < 1e-4 +print(f"attn_sums match: {attn_match}") + +# ============================================================ +# 比较 mask +# ============================================================ + +print("\n" + "=" * 60) +print("Comparing mask") +print("=" * 60) + +mask_saved_gpu = mask_saved.cuda() +mask_match = (mask_ext_valid == mask_saved_gpu).all().item() +print(f"mask exact match: {mask_match}") + +if not mask_match: + diff_count = (mask_ext_valid != mask_saved_gpu).sum().item() + total_count = mask_ext_valid.numel() + print(f"mask diff count: {diff_count} / {total_count} ({diff_count/total_count*100:.2f}%)") + +# ============================================================ +# 计算 density +# ============================================================ + +print("\n" + "=" * 60) +print("Comparing density") +print("=" * 60) + +# 计算 causal mask +q_offset_blocks = valid_k_blocks - valid_q_blocks +indices = torch.arange(valid_k_blocks, device=mask_ext_valid.device).unsqueeze(0) +q_indices = torch.arange(valid_q_blocks, device=mask_ext_valid.device).unsqueeze(1) +causal_mask = indices <= (q_indices + q_offset_blocks) + +# Density from saved mask +total_saved = causal_mask.sum().item() * mask_saved_gpu.shape[0] * mask_saved_gpu.shape[1] +selected_saved = (mask_saved_gpu & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() +density_saved = selected_saved / total_saved + +# Density from external xattn_estimate +total_ext = causal_mask.sum().item() * mask_ext_valid.shape[0] * mask_ext_valid.shape[1] +selected_ext = (mask_ext_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() +density_ext = selected_ext / total_ext + +print(f"Saved density: {density_saved:.6f} (selected={selected_saved}, total={total_saved})") +print(f"External density: {density_ext:.6f} (selected={selected_ext}, total={total_ext})") +print(f"Density diff: {abs(density_saved - density_ext):.6f}") + +# ============================================================ +# 结论 +# ============================================================ + +print("\n" + "=" * 60) +print("RESULT") +print("=" * 60) + +if attn_match and mask_match: + print("✅ PASSED: GPU-only density matches external xattn_estimate") +else: + print("❌ FAILED: Mismatch detected") + if not attn_match: + print(" - attn_sums mismatch") + if not mask_match: + print(" - mask mismatch")