test: add GPU-only density alignment verification test

Add test to verify XAttention density calculation in GPU-only mode
matches independent xattn_estimate calls.

Changes:
- Add tests/test_gpuonly_density_alignment.py: loads saved Q/K from
  xattn_bsa.py, calls xattn_estimate independently, compares results
- Enhance debug save in xattn_bsa.py: now saves Q, K tensors and
  xattn_estimate parameters for external verification
- Set _DEBUG_SAVE_MASK = False by default

Usage:
1. Set _DEBUG_SAVE_MASK = True in xattn_bsa.py
2. Run GPU-only inference with XAttention (e.g., test_ruler.py)
3. Run tests/test_gpuonly_density_alignment.py to verify alignment

Verified on 4k/8k/16k/32k/64k contexts - all pass with exact match.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-02-02 11:14:46 +08:00
parent 6c55c4d2a3
commit aeed6ccdfb
2 changed files with 162 additions and 3 deletions

View File

@@ -27,7 +27,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Global storage for mask debugging # Global storage for mask debugging
_DEBUG_SAVE_MASK = True # Set to True to save masks for comparison _DEBUG_SAVE_MASK = False # Set to True to save masks for comparison
_DEBUG_MASK_STORAGE = {} _DEBUG_MASK_STORAGE = {}
# Check BSA availability # Check BSA availability
@@ -399,7 +399,7 @@ class XAttentionBSAPolicy(SparsePolicy):
causal=True, causal=True,
) )
# Debug: Save mask and attention sums for comparison # Debug: Save Q, K, mask, attn_sums for external verification
if _DEBUG_SAVE_MASK and layer_id == 0: if _DEBUG_SAVE_MASK and layer_id == 0:
import os import os
valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
@@ -410,14 +410,24 @@ class XAttentionBSAPolicy(SparsePolicy):
os.makedirs(save_dir, exist_ok=True) os.makedirs(save_dir, exist_ok=True)
save_path = f"{save_dir}/gpuonly_layer{layer_id}.pt" save_path = f"{save_dir}/gpuonly_layer{layer_id}.pt"
torch.save({ torch.save({
# Input tensors (GQA-expanded)
"Q": Q.clone().cpu(), # [1, num_heads, q_len, head_dim]
"K": K_exp.clone().cpu(), # [1, num_heads, k_len, head_dim]
# xattn_estimate parameters
"chunk_size": self.chunk_size,
"block_size": self.BSA_BLOCK_SIZE,
"stride": self.stride,
"threshold": self.threshold,
# Output for comparison
"mask": mask_valid.clone().cpu(), "mask": mask_valid.clone().cpu(),
"attn_sums": attn_sums_valid.clone().cpu(), "attn_sums": attn_sums_valid.clone().cpu(),
# Metadata
"q_len": q_len, "q_len": q_len,
"k_len": k_len, "k_len": k_len,
"valid_q_blocks": valid_q_blocks, "valid_q_blocks": valid_q_blocks,
"valid_k_blocks": valid_k_blocks, "valid_k_blocks": valid_k_blocks,
}, save_path) }, save_path)
logger.info(f"[DEBUG] Saved mask to {save_path}, shape={mask_valid.shape}") logger.info(f"[DEBUG] Saved Q/K/mask to {save_path}, Q={Q.shape}, K={K_exp.shape}")
# Compute block counts # Compute block counts
q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE

View File

@@ -0,0 +1,149 @@
"""
Test: GPU-only density alignment verification
验证 xattn_bsa.py 中 GPU-only 路径的 density 计算是否与独立调用 xattn_estimate 一致。
流程:
1. 运行 GPU-only 推理,保存 Q, K, mask, attn_sums
2. 加载保存的数据,独立调用 xattn_estimate
3. 比较两者的 mask 和 density
"""
import torch
import sys
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
from nanovllm.ops.xattn import xattn_estimate
# ============================================================
# 参数配置
# ============================================================
DATA_PATH = "/home/zijie/Code/nano-vllm/results/mask_alignment/gpuonly_layer0.pt"
# ============================================================
# 加载保存的数据
# ============================================================
print(f"Loading data from {DATA_PATH}")
data = torch.load(DATA_PATH, weights_only=False)
Q = data["Q"].cuda() # [1, num_heads, q_len, head_dim]
K = data["K"].cuda() # [1, num_heads, k_len, head_dim]
chunk_size = data["chunk_size"]
block_size = data["block_size"]
stride = data["stride"]
threshold = data["threshold"]
mask_saved = data["mask"] # [1, num_heads, valid_q_blocks, valid_k_blocks]
attn_sums_saved = data["attn_sums"]
q_len = data["q_len"]
k_len = data["k_len"]
valid_q_blocks = data["valid_q_blocks"]
valid_k_blocks = data["valid_k_blocks"]
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"q_len: {q_len}, k_len: {k_len}")
print(f"chunk_size: {chunk_size}, block_size: {block_size}, stride: {stride}, threshold: {threshold}")
print(f"valid_q_blocks: {valid_q_blocks}, valid_k_blocks: {valid_k_blocks}")
print(f"mask_saved shape: {mask_saved.shape}")
# ============================================================
# 独立调用 xattn_estimate
# ============================================================
print("\nCalling xattn_estimate independently...")
attn_sums_ext, mask_ext = xattn_estimate(
Q, K,
chunk_size=chunk_size,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
causal=True,
)
# Trim to valid blocks
mask_ext_valid = mask_ext[:, :, :valid_q_blocks, :valid_k_blocks]
attn_sums_ext_valid = attn_sums_ext[:, :, :valid_q_blocks, :valid_k_blocks]
print(f"mask_ext shape: {mask_ext.shape}")
print(f"mask_ext_valid shape: {mask_ext_valid.shape}")
# ============================================================
# 比较 attn_sums
# ============================================================
print("\n" + "=" * 60)
print("Comparing attn_sums")
print("=" * 60)
attn_sums_saved_gpu = attn_sums_saved.cuda()
attn_diff = (attn_sums_ext_valid - attn_sums_saved_gpu).abs()
print(f"attn_sums max diff: {attn_diff.max().item():.6e}")
print(f"attn_sums mean diff: {attn_diff.mean().item():.6e}")
# Check if attn_sums match
attn_match = attn_diff.max().item() < 1e-4
print(f"attn_sums match: {attn_match}")
# ============================================================
# 比较 mask
# ============================================================
print("\n" + "=" * 60)
print("Comparing mask")
print("=" * 60)
mask_saved_gpu = mask_saved.cuda()
mask_match = (mask_ext_valid == mask_saved_gpu).all().item()
print(f"mask exact match: {mask_match}")
if not mask_match:
diff_count = (mask_ext_valid != mask_saved_gpu).sum().item()
total_count = mask_ext_valid.numel()
print(f"mask diff count: {diff_count} / {total_count} ({diff_count/total_count*100:.2f}%)")
# ============================================================
# 计算 density
# ============================================================
print("\n" + "=" * 60)
print("Comparing density")
print("=" * 60)
# 计算 causal mask
q_offset_blocks = valid_k_blocks - valid_q_blocks
indices = torch.arange(valid_k_blocks, device=mask_ext_valid.device).unsqueeze(0)
q_indices = torch.arange(valid_q_blocks, device=mask_ext_valid.device).unsqueeze(1)
causal_mask = indices <= (q_indices + q_offset_blocks)
# Density from saved mask
total_saved = causal_mask.sum().item() * mask_saved_gpu.shape[0] * mask_saved_gpu.shape[1]
selected_saved = (mask_saved_gpu & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_saved = selected_saved / total_saved
# Density from external xattn_estimate
total_ext = causal_mask.sum().item() * mask_ext_valid.shape[0] * mask_ext_valid.shape[1]
selected_ext = (mask_ext_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
density_ext = selected_ext / total_ext
print(f"Saved density: {density_saved:.6f} (selected={selected_saved}, total={total_saved})")
print(f"External density: {density_ext:.6f} (selected={selected_ext}, total={total_ext})")
print(f"Density diff: {abs(density_saved - density_ext):.6f}")
# ============================================================
# 结论
# ============================================================
print("\n" + "=" * 60)
print("RESULT")
print("=" * 60)
if attn_match and mask_match:
print("✅ PASSED: GPU-only density matches external xattn_estimate")
else:
print("❌ FAILED: Mismatch detected")
if not attn_match:
print(" - attn_sums mismatch")
if not mask_match:
print(" - mask mismatch")