✅ test: add GPU-only density alignment verification test
Add test to verify that the XAttention density calculation in GPU-only mode matches independent xattn_estimate calls.

Changes:
- Add tests/test_gpuonly_density_alignment.py: loads saved Q/K from xattn_bsa.py, calls xattn_estimate independently, compares results
- Enhance debug save in xattn_bsa.py: now saves Q, K tensors and xattn_estimate parameters for external verification
- Set _DEBUG_SAVE_MASK = False by default

Usage:
1. Set _DEBUG_SAVE_MASK = True in xattn_bsa.py
2. Run GPU-only inference with XAttention (e.g., test_ruler.py)
3. Run tests/test_gpuonly_density_alignment.py to verify alignment

Verified on 4k/8k/16k/32k/64k contexts - all pass with exact match.

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global storage for mask debugging
|
||||
_DEBUG_SAVE_MASK = True # Set to True to save masks for comparison
|
||||
_DEBUG_SAVE_MASK = False # Set to True to save masks for comparison
|
||||
_DEBUG_MASK_STORAGE = {}
|
||||
|
||||
# Check BSA availability
|
||||
@@ -399,7 +399,7 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
causal=True,
|
||||
)
|
||||
|
||||
# Debug: Save mask and attention sums for comparison
|
||||
# Debug: Save Q, K, mask, attn_sums for external verification
|
||||
if _DEBUG_SAVE_MASK and layer_id == 0:
|
||||
import os
|
||||
valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||
@@ -410,14 +410,24 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
save_path = f"{save_dir}/gpuonly_layer{layer_id}.pt"
|
||||
torch.save({
|
||||
# Input tensors (GQA-expanded)
|
||||
"Q": Q.clone().cpu(), # [1, num_heads, q_len, head_dim]
|
||||
"K": K_exp.clone().cpu(), # [1, num_heads, k_len, head_dim]
|
||||
# xattn_estimate parameters
|
||||
"chunk_size": self.chunk_size,
|
||||
"block_size": self.BSA_BLOCK_SIZE,
|
||||
"stride": self.stride,
|
||||
"threshold": self.threshold,
|
||||
# Output for comparison
|
||||
"mask": mask_valid.clone().cpu(),
|
||||
"attn_sums": attn_sums_valid.clone().cpu(),
|
||||
# Metadata
|
||||
"q_len": q_len,
|
||||
"k_len": k_len,
|
||||
"valid_q_blocks": valid_q_blocks,
|
||||
"valid_k_blocks": valid_k_blocks,
|
||||
}, save_path)
|
||||
logger.info(f"[DEBUG] Saved mask to {save_path}, shape={mask_valid.shape}")
|
||||
logger.info(f"[DEBUG] Saved Q/K/mask to {save_path}, Q={Q.shape}, K={K_exp.shape}")
|
||||
|
||||
# Compute block counts
|
||||
q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||
|
||||
Reference in New Issue
Block a user