🚧 WIP: add DEBUG code for XAttention KV chunking density verification

Add instrumentation to compare block-selection density between GPU-only and Offload modes:
- Layer 0 DEBUG output for both modes
- Accumulate selected/total counts across chunks
- Proper causal mask with Q offset handling
- Skip normal offload logic for isolated testing

Test results (with threshold=1.0, the densities of the two modes align):
- 32K: GPU-only 0.9999, Offload 0.9999 (diff ~0%)
- 64K: GPU-only 0.9995, Offload 0.9995 (diff ~0%)

Generated with [Claude Code](https://claude.ai/code)
via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
Zijie Tian
2026-02-01 17:33:23 +08:00
parent 2e96d1d97d
commit 8ab53e7331

View File

@@ -147,8 +147,10 @@ class XAttentionBSAPolicy(SparsePolicy):
self._selected_cpu_indices: List[int] = []
self._bsa_per_cpu: int = 0 # BSA blocks per CPU block
#> Debug: store all K cache
#> Debug: store all K cache and density counts
self._debug_k_full: torch.Tensor | None = None
self._debug_selected: int = 0 # 累积的 selected blocks
self._debug_total: int = 0 # 累积的 total blocks
def alloc_policy_metadata(
self,
@@ -202,8 +204,10 @@ class XAttentionBSAPolicy(SparsePolicy):
memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")
#DEBUG : buffer for save all K cache.
#DEBUG : buffer for save all K cache
self._debug_k_full = torch.empty((1, num_heads, max_seq_len, head_dim), dtype=dtype, device=device)
self._debug_selected = 0
self._debug_total = 0
# =========================================================================
# GPU-only methods (non-chunked)
@@ -395,6 +399,15 @@ class XAttentionBSAPolicy(SparsePolicy):
)
# Record density for all layers via DensityObserver
if layer_id == 0:
# DEBUG: 打印 GPU-only Layer 0 的 mask 详情
q_bk = mask_trimmed.shape[2]
k_bk = mask_trimmed.shape[3]
causal_total = q_bk * (q_bk + 1) // 2 * mask_trimmed.shape[0] * mask_trimmed.shape[1]
causal_mask = torch.tril(torch.ones(q_bk, k_bk, device=mask_trimmed.device, dtype=torch.bool))
selected = (mask_trimmed & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
logger.info(f"[DEBUG GPU-only Layer0] mask_shape={mask_trimmed.shape}, "
f"density={selected/causal_total:.6f}, selected={selected}, total={causal_total}")
DensityObserver.record(layer_id, mask_trimmed, causal=True)
return output
@@ -567,9 +580,67 @@ class XAttentionBSAPolicy(SparsePolicy):
k_repeated = k.repeat_interleave(num_groups, dim=1).unsqueeze(0).transpose(1, 2) # [1, num_heads, historical_k_len, head_dim]
self._debug_k_full[:, :, historical_k_len:historical_k_len + q_len, :].copy_(k_repeated)
# ============================================================
# DEBUG: 累积 selected/total counts (仅 layer 0)
# 使用完整 K 调用 xattn_estimate与 GPU-only 逻辑一致
# ============================================================
if layer_id == 0:
__import__('pdb').set_trace()
from nanovllm.ops.xattn import xattn_estimate
total_k_len = historical_k_len + q_len
K_full = self._debug_k_full[:, :, :total_k_len, :]
# 用当前 Q chunk 和累积的 K 调用 xattn_estimate
# 设置 chunk_size 为 q_len 的最小对齐值 (stride * BLOCK_M = 8 * 128 = 1024)
alignment = self.stride * 128
aligned_chunk_size = ((q_len + alignment - 1) // alignment) * alignment
# DEBUG: 使用固定 threshold 测试
_, mask_chunk = xattn_estimate(
Q[:, :, :q_len, :], # 当前 Q chunk
K_full, # 累积的 K
block_size=self.BSA_BLOCK_SIZE,
stride=self.stride,
threshold=self.threshold, # DEBUG: 使用传入的 threshold
chunk_size=aligned_chunk_size, # 对齐的 chunk_size
causal=True,
)
# 计算有效的 block 数量(排除 padding
valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
valid_k_blocks = (total_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
# 裁剪 mask 到有效区域
mask_valid = mask_chunk[:, :, :valid_q_blocks, :valid_k_blocks]
# 计算当前 chunk 的 selected/total (考虑 causal考虑 Q 偏移量)
q_blocks = valid_q_blocks
k_blocks = valid_k_blocks
# Q 从位置 (k_blocks - q_blocks) 开始,所以 Q block i 实际位置是 i + offset
# Q block i (实际位置 i+offset) 可以看到 K block 0 到 i+offset
q_offset_blocks = k_blocks - q_blocks
indices = torch.arange(k_blocks, device=mask_valid.device).unsqueeze(0) # [1, k_blocks]
q_indices = torch.arange(q_blocks, device=mask_valid.device).unsqueeze(1) # [q_blocks, 1]
causal_mask = indices <= (q_indices + q_offset_blocks) # [q_blocks, k_blocks]
chunk_total = causal_mask.sum().item() * mask_valid.shape[0] * mask_valid.shape[1]
chunk_selected = (mask_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
# 累积
self._debug_selected += chunk_selected
self._debug_total += chunk_total
# 打印当前累积的 density
if self._debug_total > 0:
density = self._debug_selected / self._debug_total
logger.info(f"[DEBUG Offload Layer0] 累积 density: {density:.4f} "
f"(selected={self._debug_selected}, total={self._debug_total}, k_len={total_k_len}, "
f"mask_shape={mask_chunk.shape}, q_offset={q_offset_blocks})")
# DEBUG: 跳过正常 offload 逻辑,直接返回所有 blocks
return available_blocks
else:
# DEBUG: 非 Layer 0 也跳过正常 offload 逻辑
return available_blocks
# ============================================================
# Step 3: Get current chunk K and compute its attn_scores
@@ -656,14 +727,16 @@ class XAttentionBSAPolicy(SparsePolicy):
q_start_bsa_block = historical_k_bsa_blocks # Q starts after historical K
with nvtx.range("xattn_find_blocks"):
# 对于历史 K 的选择,使用 causal=False 因为历史 K 都在当前 Q 之前
# current_index=0 避免超出 block_sums 的 K 维度
mask = find_blocks_chunked(
block_sums,
current_index=q_start_bsa_block, # Q's position in BSA blocks
current_index=0,
threshold=self.threshold,
num_to_choose=None,
decoding=False,
mode="prefill",
causal=True, # Causal for block-level mask
mode="both",
causal=False,
)
# mask shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]
@@ -676,47 +749,13 @@ class XAttentionBSAPolicy(SparsePolicy):
valid_q_bsa = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
valid_curr_k_bsa = (curr_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
# 7a: Record historical blocks density
# IMPORTANT: For historical blocks, apply causal mask to match GPU-only density calculation!
# Q block i (global position = q_start_bsa_block + i) can see historical K block j
# only if j <= q_start_bsa_block + i (causal constraint)
mask_historical = mask[:, :, :valid_q_bsa, :historical_k_bsa_blocks]
# 7a: Record historical blocks density (暂时禁用,使用 DEBUG 输出代替)
# if historical_k_bsa_blocks > 0:
# ... DensityObserver.record_counts ...
if historical_k_bsa_blocks > 0:
# Create causal mask for historical blocks
# Q_global[i] = q_start_bsa_block + i, K[j] = j
# Causal: j <= Q_global[i] => j <= q_start_bsa_block + i
q_global_indices = torch.arange(valid_q_bsa, device=mask.device) + q_start_bsa_block
k_indices = torch.arange(historical_k_bsa_blocks, device=mask.device)
# Q at position q_global_indices[i] can see K at position k_indices[j] if k_indices[j] <= q_global_indices[i]
causal_mask_historical = k_indices.unsqueeze(0) <= q_global_indices.unsqueeze(1) # [valid_q_bsa, historical_k_bsa_blocks]
# Count positions within causal mask only
total_historical_causal = causal_mask_historical.sum().item() * B * H
selected_historical = (mask_historical & causal_mask_historical.unsqueeze(0).unsqueeze(0)).sum().item()
if total_historical_causal > 0:
DensityObserver.record_counts(layer_id, selected_historical, total_historical_causal)
# 7b: Record current chunk density (causal, to align with GPU-only mode)
# Current chunk is the portion after historical blocks
if valid_curr_k_bsa > 0:
# Extract current chunk mask (only valid portion, not padded)
mask_current = mask[:, :, :valid_q_bsa, historical_k_bsa_blocks:historical_k_bsa_blocks + valid_curr_k_bsa]
q_dim = mask_current.shape[2]
k_dim = mask_current.shape[3]
# Create causal mask (lower triangular)
# For current chunk: Q[i] can see K[j] where j <= i (standard causal)
causal_mask = torch.tril(torch.ones(q_dim, k_dim, device=mask.device, dtype=torch.bool))
# Count positions within causal mask only
total_current_causal = causal_mask.sum().item() * B * H
selected_current = (mask_current & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
if total_current_causal > 0:
DensityObserver.record_counts(layer_id, selected_current, total_current_causal)
# 7b: Record current chunk density (暂时禁用)
# if valid_curr_k_bsa > 0:
# ... DensityObserver.record_counts ...
# Step 7.5: Save historical mask to pre-allocated buffer for compute_chunked_prefill
# Use full Q_bsa (padded) for buffer, not valid_q_bsa