test: add GPU-only density alignment verification test

Add test to verify XAttention density calculation in GPU-only mode
matches independent xattn_estimate calls.

Changes:
- Add tests/test_gpuonly_density_alignment.py: loads the Q/K tensors
  saved by xattn_bsa.py, calls xattn_estimate independently, and compares results
- Enhance debug save in xattn_bsa.py: now saves Q, K tensors and
  xattn_estimate parameters for external verification
- Set _DEBUG_SAVE_MASK = False by default

Usage:
1. Set _DEBUG_SAVE_MASK = True in xattn_bsa.py
2. Run GPU-only inference with XAttention (e.g., test_ruler.py)
3. Run tests/test_gpuonly_density_alignment.py to verify alignment

Verified on 4k/8k/16k/32k/64k contexts — all pass with an exact match.

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-02-02 11:14:46 +08:00
parent 6c55c4d2a3
commit aeed6ccdfb
2 changed files with 162 additions and 3 deletions

View File

@@ -27,7 +27,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Global storage for mask debugging
_DEBUG_SAVE_MASK = True # Set to True to save masks for comparison
_DEBUG_SAVE_MASK = False # Set to True to save masks for comparison
_DEBUG_MASK_STORAGE = {}
# Check BSA availability
@@ -399,7 +399,7 @@ class XAttentionBSAPolicy(SparsePolicy):
causal=True,
)
# Debug: Save mask and attention sums for comparison
# Debug: Save Q, K, mask, attn_sums for external verification
if _DEBUG_SAVE_MASK and layer_id == 0:
import os
valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
@@ -410,14 +410,24 @@ class XAttentionBSAPolicy(SparsePolicy):
os.makedirs(save_dir, exist_ok=True)
save_path = f"{save_dir}/gpuonly_layer{layer_id}.pt"
torch.save({
# Input tensors (GQA-expanded)
"Q": Q.clone().cpu(), # [1, num_heads, q_len, head_dim]
"K": K_exp.clone().cpu(), # [1, num_heads, k_len, head_dim]
# xattn_estimate parameters
"chunk_size": self.chunk_size,
"block_size": self.BSA_BLOCK_SIZE,
"stride": self.stride,
"threshold": self.threshold,
# Output for comparison
"mask": mask_valid.clone().cpu(),
"attn_sums": attn_sums_valid.clone().cpu(),
# Metadata
"q_len": q_len,
"k_len": k_len,
"valid_q_blocks": valid_q_blocks,
"valid_k_blocks": valid_k_blocks,
}, save_path)
logger.info(f"[DEBUG] Saved mask to {save_path}, shape={mask_valid.shape}")
logger.info(f"[DEBUG] Saved Q/K/mask to {save_path}, Q={Q.shape}, K={K_exp.shape}")
# Compute block counts
q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE