🚧 WIP: add DEBUG code for XAttention KV chunking density verification
Add instrumentation to compare GPU-only vs Offload mode density:

- Layer 0 DEBUG output for both modes
- Accumulate selected/total counts across chunks
- Proper causal mask with Q offset handling
- Skip normal offload logic for isolated testing

Test results (threshold=1.0 achieves alignment):

- 32K: GPU-only 0.9999, Offload 0.9999 (diff ~0%)
- 64K: GPU-only 0.9995, Offload 0.9995 (diff ~0%)

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
This commit is contained in:
@@ -147,8 +147,10 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
self._selected_cpu_indices: List[int] = []
|
||||
self._bsa_per_cpu: int = 0 # BSA blocks per CPU block
|
||||
|
||||
#> Debug: store all K cache
|
||||
#> Debug: store all K cache and density counts
|
||||
self._debug_k_full: torch.Tensor | None = None
|
||||
self._debug_selected: int = 0 # 累积的 selected blocks
|
||||
self._debug_total: int = 0 # 累积的 total blocks
|
||||
|
||||
def alloc_policy_metadata(
|
||||
self,
|
||||
@@ -202,8 +204,10 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)
|
||||
logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")
|
||||
|
||||
#DEBUG : buffer for save all K cache.
|
||||
#DEBUG : buffer for save all K cache
|
||||
self._debug_k_full = torch.empty((1, num_heads, max_seq_len, head_dim), dtype=dtype, device=device)
|
||||
self._debug_selected = 0
|
||||
self._debug_total = 0
|
||||
|
||||
# =========================================================================
|
||||
# GPU-only methods (non-chunked)
|
||||
@@ -395,6 +399,15 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
)
|
||||
|
||||
# Record density for all layers via DensityObserver
|
||||
if layer_id == 0:
|
||||
# DEBUG: 打印 GPU-only Layer 0 的 mask 详情
|
||||
q_bk = mask_trimmed.shape[2]
|
||||
k_bk = mask_trimmed.shape[3]
|
||||
causal_total = q_bk * (q_bk + 1) // 2 * mask_trimmed.shape[0] * mask_trimmed.shape[1]
|
||||
causal_mask = torch.tril(torch.ones(q_bk, k_bk, device=mask_trimmed.device, dtype=torch.bool))
|
||||
selected = (mask_trimmed & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
logger.info(f"[DEBUG GPU-only Layer0] mask_shape={mask_trimmed.shape}, "
|
||||
f"density={selected/causal_total:.6f}, selected={selected}, total={causal_total}")
|
||||
DensityObserver.record(layer_id, mask_trimmed, causal=True)
|
||||
|
||||
return output
|
||||
@@ -567,9 +580,67 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
k_repeated = k.repeat_interleave(num_groups, dim=1).unsqueeze(0).transpose(1, 2) # [1, num_heads, historical_k_len, head_dim]
|
||||
|
||||
self._debug_k_full[:, :, historical_k_len:historical_k_len + q_len, :].copy_(k_repeated)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# DEBUG: 累积 selected/total counts (仅 layer 0)
|
||||
# 使用完整 K 调用 xattn_estimate,与 GPU-only 逻辑一致
|
||||
# ============================================================
|
||||
if layer_id == 0:
|
||||
__import__('pdb').set_trace()
|
||||
from nanovllm.ops.xattn import xattn_estimate
|
||||
|
||||
total_k_len = historical_k_len + q_len
|
||||
K_full = self._debug_k_full[:, :, :total_k_len, :]
|
||||
|
||||
# 用当前 Q chunk 和累积的 K 调用 xattn_estimate
|
||||
# 设置 chunk_size 为 q_len 的最小对齐值 (stride * BLOCK_M = 8 * 128 = 1024)
|
||||
alignment = self.stride * 128
|
||||
aligned_chunk_size = ((q_len + alignment - 1) // alignment) * alignment
|
||||
# DEBUG: 使用固定 threshold 测试
|
||||
_, mask_chunk = xattn_estimate(
|
||||
Q[:, :, :q_len, :], # 当前 Q chunk
|
||||
K_full, # 累积的 K
|
||||
block_size=self.BSA_BLOCK_SIZE,
|
||||
stride=self.stride,
|
||||
threshold=self.threshold, # DEBUG: 使用传入的 threshold
|
||||
chunk_size=aligned_chunk_size, # 对齐的 chunk_size
|
||||
causal=True,
|
||||
)
|
||||
|
||||
# 计算有效的 block 数量(排除 padding)
|
||||
valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||
valid_k_blocks = (total_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||
|
||||
# 裁剪 mask 到有效区域
|
||||
mask_valid = mask_chunk[:, :, :valid_q_blocks, :valid_k_blocks]
|
||||
|
||||
# 计算当前 chunk 的 selected/total (考虑 causal,考虑 Q 偏移量)
|
||||
q_blocks = valid_q_blocks
|
||||
k_blocks = valid_k_blocks
|
||||
# Q 从位置 (k_blocks - q_blocks) 开始,所以 Q block i 实际位置是 i + offset
|
||||
# Q block i (实际位置 i+offset) 可以看到 K block 0 到 i+offset
|
||||
q_offset_blocks = k_blocks - q_blocks
|
||||
indices = torch.arange(k_blocks, device=mask_valid.device).unsqueeze(0) # [1, k_blocks]
|
||||
q_indices = torch.arange(q_blocks, device=mask_valid.device).unsqueeze(1) # [q_blocks, 1]
|
||||
causal_mask = indices <= (q_indices + q_offset_blocks) # [q_blocks, k_blocks]
|
||||
chunk_total = causal_mask.sum().item() * mask_valid.shape[0] * mask_valid.shape[1]
|
||||
chunk_selected = (mask_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
|
||||
# 累积
|
||||
self._debug_selected += chunk_selected
|
||||
self._debug_total += chunk_total
|
||||
|
||||
# 打印当前累积的 density
|
||||
if self._debug_total > 0:
|
||||
density = self._debug_selected / self._debug_total
|
||||
logger.info(f"[DEBUG Offload Layer0] 累积 density: {density:.4f} "
|
||||
f"(selected={self._debug_selected}, total={self._debug_total}, k_len={total_k_len}, "
|
||||
f"mask_shape={mask_chunk.shape}, q_offset={q_offset_blocks})")
|
||||
|
||||
# DEBUG: 跳过正常 offload 逻辑,直接返回所有 blocks
|
||||
return available_blocks
|
||||
else:
|
||||
# DEBUG: 非 Layer 0 也跳过正常 offload 逻辑
|
||||
return available_blocks
|
||||
|
||||
# ============================================================
|
||||
# Step 3: Get current chunk K and compute its attn_scores
|
||||
@@ -656,14 +727,16 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
q_start_bsa_block = historical_k_bsa_blocks # Q starts after historical K
|
||||
|
||||
with nvtx.range("xattn_find_blocks"):
|
||||
# 对于历史 K 的选择,使用 causal=False 因为历史 K 都在当前 Q 之前
|
||||
# current_index=0 避免超出 block_sums 的 K 维度
|
||||
mask = find_blocks_chunked(
|
||||
block_sums,
|
||||
current_index=q_start_bsa_block, # Q's position in BSA blocks
|
||||
current_index=0,
|
||||
threshold=self.threshold,
|
||||
num_to_choose=None,
|
||||
decoding=False,
|
||||
mode="prefill",
|
||||
causal=True, # Causal for block-level mask
|
||||
mode="both",
|
||||
causal=False,
|
||||
)
|
||||
# mask shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks]
|
||||
|
||||
@@ -676,47 +749,13 @@ class XAttentionBSAPolicy(SparsePolicy):
|
||||
valid_q_bsa = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||
valid_curr_k_bsa = (curr_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||
|
||||
# 7a: Record historical blocks density
|
||||
# IMPORTANT: For historical blocks, apply causal mask to match GPU-only density calculation!
|
||||
# Q block i (global position = q_start_bsa_block + i) can see historical K block j
|
||||
# only if j <= q_start_bsa_block + i (causal constraint)
|
||||
mask_historical = mask[:, :, :valid_q_bsa, :historical_k_bsa_blocks]
|
||||
# 7a: Record historical blocks density (暂时禁用,使用 DEBUG 输出代替)
|
||||
# if historical_k_bsa_blocks > 0:
|
||||
# ... DensityObserver.record_counts ...
|
||||
|
||||
if historical_k_bsa_blocks > 0:
|
||||
# Create causal mask for historical blocks
|
||||
# Q_global[i] = q_start_bsa_block + i, K[j] = j
|
||||
# Causal: j <= Q_global[i] => j <= q_start_bsa_block + i
|
||||
q_global_indices = torch.arange(valid_q_bsa, device=mask.device) + q_start_bsa_block
|
||||
k_indices = torch.arange(historical_k_bsa_blocks, device=mask.device)
|
||||
# Q at position q_global_indices[i] can see K at position k_indices[j] if k_indices[j] <= q_global_indices[i]
|
||||
causal_mask_historical = k_indices.unsqueeze(0) <= q_global_indices.unsqueeze(1) # [valid_q_bsa, historical_k_bsa_blocks]
|
||||
|
||||
# Count positions within causal mask only
|
||||
total_historical_causal = causal_mask_historical.sum().item() * B * H
|
||||
selected_historical = (mask_historical & causal_mask_historical.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
|
||||
if total_historical_causal > 0:
|
||||
DensityObserver.record_counts(layer_id, selected_historical, total_historical_causal)
|
||||
|
||||
# 7b: Record current chunk density (causal, to align with GPU-only mode)
|
||||
# Current chunk is the portion after historical blocks
|
||||
if valid_curr_k_bsa > 0:
|
||||
# Extract current chunk mask (only valid portion, not padded)
|
||||
mask_current = mask[:, :, :valid_q_bsa, historical_k_bsa_blocks:historical_k_bsa_blocks + valid_curr_k_bsa]
|
||||
|
||||
q_dim = mask_current.shape[2]
|
||||
k_dim = mask_current.shape[3]
|
||||
|
||||
# Create causal mask (lower triangular)
|
||||
# For current chunk: Q[i] can see K[j] where j <= i (standard causal)
|
||||
causal_mask = torch.tril(torch.ones(q_dim, k_dim, device=mask.device, dtype=torch.bool))
|
||||
|
||||
# Count positions within causal mask only
|
||||
total_current_causal = causal_mask.sum().item() * B * H
|
||||
selected_current = (mask_current & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
|
||||
|
||||
if total_current_causal > 0:
|
||||
DensityObserver.record_counts(layer_id, selected_current, total_current_causal)
|
||||
# 7b: Record current chunk density (暂时禁用)
|
||||
# if valid_curr_k_bsa > 0:
|
||||
# ... DensityObserver.record_counts ...
|
||||
|
||||
# Step 7.5: Save historical mask to pre-allocated buffer for compute_chunked_prefill
|
||||
# Use full Q_bsa (padded) for buffer, not valid_q_bsa
|
||||
|
||||
Reference in New Issue
Block a user