diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index 92adc48..afc3dca 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -147,8 +147,10 @@ class XAttentionBSAPolicy(SparsePolicy): self._selected_cpu_indices: List[int] = [] self._bsa_per_cpu: int = 0 # BSA blocks per CPU block - #> Debug: store all K cache + #> Debug: store all K cache and density counts self._debug_k_full: torch.Tensor | None = None + self._debug_selected: int = 0 # 累积的 selected blocks + self._debug_total: int = 0 # 累积的 total blocks def alloc_policy_metadata( self, @@ -202,8 +204,10 @@ class XAttentionBSAPolicy(SparsePolicy): memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024) logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB") - #DEBUG : buffer for save all K cache. + #DEBUG : buffer for save all K cache self._debug_k_full = torch.empty((1, num_heads, max_seq_len, head_dim), dtype=dtype, device=device) + self._debug_selected = 0 + self._debug_total = 0 # ========================================================================= # GPU-only methods (non-chunked) @@ -395,6 +399,15 @@ class XAttentionBSAPolicy(SparsePolicy): ) # Record density for all layers via DensityObserver + if layer_id == 0: + # DEBUG: 打印 GPU-only Layer 0 的 mask 详情 + q_bk = mask_trimmed.shape[2] + k_bk = mask_trimmed.shape[3] + causal_total = q_bk * (q_bk + 1) // 2 * mask_trimmed.shape[0] * mask_trimmed.shape[1] + causal_mask = torch.tril(torch.ones(q_bk, k_bk, device=mask_trimmed.device, dtype=torch.bool)) + selected = (mask_trimmed & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() + logger.info(f"[DEBUG GPU-only Layer0] mask_shape={mask_trimmed.shape}, " + f"density={selected/causal_total:.6f}, selected={selected}, total={causal_total}") DensityObserver.record(layer_id, mask_trimmed, causal=True) return output @@ -567,9 +580,67 @@ class XAttentionBSAPolicy(SparsePolicy): k_repeated = k.repeat_interleave(num_groups, dim=1).unsqueeze(0).transpose(1, 2) # [1, num_heads, historical_k_len, head_dim] self._debug_k_full[:, :, historical_k_len:historical_k_len + q_len, :].copy_(k_repeated) - + + # ============================================================ + # DEBUG: 累积 selected/total counts (仅 layer 0) + # 使用完整 K 调用 xattn_estimate,与 GPU-only 逻辑一致 + # ============================================================ if layer_id == 0: - __import__('pdb').set_trace() + from nanovllm.ops.xattn import xattn_estimate + + total_k_len = historical_k_len + q_len + K_full = self._debug_k_full[:, :, :total_k_len, :] + + # 用当前 Q chunk 和累积的 K 调用 xattn_estimate + # 设置 chunk_size 为 q_len 的最小对齐值 (stride * BLOCK_M = 8 * 128 = 1024) + alignment = self.stride * 128 + aligned_chunk_size = ((q_len + alignment - 1) // alignment) * alignment + # DEBUG: 使用固定 threshold 测试 + _, mask_chunk = xattn_estimate( + Q[:, :, :q_len, :], # 当前 Q chunk + K_full, # 累积的 K + block_size=self.BSA_BLOCK_SIZE, + stride=self.stride, + threshold=self.threshold, # DEBUG: 使用传入的 threshold + chunk_size=aligned_chunk_size, # 对齐的 chunk_size + causal=True, + ) + + # 计算有效的 block 数量(排除 padding) + valid_q_blocks = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE + valid_k_blocks = (total_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE + + # 裁剪 mask 到有效区域 + mask_valid = mask_chunk[:, :, :valid_q_blocks, :valid_k_blocks] + + # 计算当前 chunk 的 selected/total (考虑 causal,考虑 Q 偏移量) + q_blocks = valid_q_blocks + k_blocks = valid_k_blocks + # Q 从位置 (k_blocks - q_blocks) 开始,所以 Q block i 实际位置是 i + offset + # Q block i (实际位置 i+offset) 可以看到 K block 0 到 i+offset + q_offset_blocks = k_blocks - q_blocks + indices = torch.arange(k_blocks, device=mask_valid.device).unsqueeze(0) # [1, k_blocks] + q_indices = torch.arange(q_blocks, device=mask_valid.device).unsqueeze(1) # [q_blocks, 1] + causal_mask = indices <= (q_indices + q_offset_blocks) # [q_blocks, k_blocks] + chunk_total = causal_mask.sum().item() * mask_valid.shape[0] * mask_valid.shape[1] + chunk_selected = (mask_valid & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() + + # 累积 + self._debug_selected += chunk_selected + self._debug_total += chunk_total + + # 打印当前累积的 density + if self._debug_total > 0: + density = self._debug_selected / self._debug_total + logger.info(f"[DEBUG Offload Layer0] 累积 density: {density:.4f} " + f"(selected={self._debug_selected}, total={self._debug_total}, k_len={total_k_len}, " + f"mask_shape={mask_chunk.shape}, q_offset={q_offset_blocks})") + + # DEBUG: 跳过正常 offload 逻辑,直接返回所有 blocks + return available_blocks + else: + # DEBUG: 非 Layer 0 也跳过正常 offload 逻辑 + return available_blocks # ============================================================ # Step 3: Get current chunk K and compute its attn_scores @@ -656,14 +727,16 @@ class XAttentionBSAPolicy(SparsePolicy): q_start_bsa_block = historical_k_bsa_blocks # Q starts after historical K with nvtx.range("xattn_find_blocks"): + # 对于历史 K 的选择,使用 causal=False 因为历史 K 都在当前 Q 之前 + # current_index=0 避免超出 block_sums 的 K 维度 mask = find_blocks_chunked( block_sums, - current_index=q_start_bsa_block, # Q's position in BSA blocks + current_index=0, threshold=self.threshold, num_to_choose=None, decoding=False, - mode="prefill", - causal=True, # Causal for block-level mask + mode="both", + causal=False, ) # mask shape: [batch, heads, q_bsa_blocks, total_k_bsa_blocks] @@ -676,47 +749,13 @@ class XAttentionBSAPolicy(SparsePolicy): valid_q_bsa = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE valid_curr_k_bsa = (curr_k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE - # 7a: Record historical blocks density - # IMPORTANT: For historical blocks, apply causal mask to match GPU-only density calculation! - # Q block i (global position = q_start_bsa_block + i) can see historical K block j - # only if j <= q_start_bsa_block + i (causal constraint) - mask_historical = mask[:, :, :valid_q_bsa, :historical_k_bsa_blocks] + # 7a: Record historical blocks density (暂时禁用,使用 DEBUG 输出代替) + # if historical_k_bsa_blocks > 0: + # ... DensityObserver.record_counts ... - if historical_k_bsa_blocks > 0: - # Create causal mask for historical blocks - # Q_global[i] = q_start_bsa_block + i, K[j] = j - # Causal: j <= Q_global[i] => j <= q_start_bsa_block + i - q_global_indices = torch.arange(valid_q_bsa, device=mask.device) + q_start_bsa_block - k_indices = torch.arange(historical_k_bsa_blocks, device=mask.device) - # Q at position q_global_indices[i] can see K at position k_indices[j] if k_indices[j] <= q_global_indices[i] - causal_mask_historical = k_indices.unsqueeze(0) <= q_global_indices.unsqueeze(1) # [valid_q_bsa, historical_k_bsa_blocks] - - # Count positions within causal mask only - total_historical_causal = causal_mask_historical.sum().item() * B * H - selected_historical = (mask_historical & causal_mask_historical.unsqueeze(0).unsqueeze(0)).sum().item() - - if total_historical_causal > 0: - DensityObserver.record_counts(layer_id, selected_historical, total_historical_causal) - - # 7b: Record current chunk density (causal, to align with GPU-only mode) - # Current chunk is the portion after historical blocks - if valid_curr_k_bsa > 0: - # Extract current chunk mask (only valid portion, not padded) - mask_current = mask[:, :, :valid_q_bsa, historical_k_bsa_blocks:historical_k_bsa_blocks + valid_curr_k_bsa] - - q_dim = mask_current.shape[2] - k_dim = mask_current.shape[3] - - # Create causal mask (lower triangular) - # For current chunk: Q[i] can see K[j] where j <= i (standard causal) - causal_mask = torch.tril(torch.ones(q_dim, k_dim, device=mask.device, dtype=torch.bool)) - - # Count positions within causal mask only - total_current_causal = causal_mask.sum().item() * B * H - selected_current = (mask_current & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() - - if total_current_causal > 0: - DensityObserver.record_counts(layer_id, selected_current, total_current_causal) + # 7b: Record current chunk density (暂时禁用) + # if valid_curr_k_bsa > 0: + # ... DensityObserver.record_counts ... # Step 7.5: Save historical mask to pre-allocated buffer for compute_chunked_prefill # Use full Q_bsa (padded) for buffer, not valid_q_bsa