🐛 fix: stream synchronization for XAttention estimate kernels in offload mode

- Wrap all compute kernels in select_blocks with compute_stream context
  (Pass 1 historical blocks, Pass 1 current chunk, Step 2 merge,
   Pass 2 historical blocks, Pass 2 current chunk, Step 4 block selection)
- Fix K data mismatch between Pass 1 and Pass 2 by ensuring wait_slot_layer
  syncs with compute_stream where kernels actually run
- Remove STRONG SYNC code from offload_engine.py (now handled by events)
- Remove debug print statements and torch.save code
- Consolidate fallback conditions in compute_with_xattn
- Change default chunk_size from 16384 to 4096 for density alignment

The bug caused Pass 1 and Pass 2 to see different K data from the same
CPU block because compute kernels ran on default stream while
wait_slot_layer only synced compute_stream.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Zijie Tian
2026-02-05 01:30:23 +08:00
parent dd0472aea8
commit 829b311c02

View File

@@ -96,7 +96,7 @@ class XAttentionBSAPolicy(SparsePolicy):
self,
threshold: float = 0.95, # High threshold for accuracy testing
stride: int = 8,
chunk_size: int = 16384,
chunk_size: int = 4096, # Match offload Q chunk size for density alignment
block_size: int = 128,
samples_per_chunk: int = 128,
use_triton: bool = True,
@@ -289,9 +289,11 @@ class XAttentionBSAPolicy(SparsePolicy):
Returns:
Attention output [total_q, num_heads, head_dim]
"""
# When block_tables is provided (paged KV cache / prefix cache),
# fallback to flash_attn as XAttention expects contiguous K, V
if block_tables is not None:
# Fallback to flash attention when:
# 1. block_tables provided (paged KV cache / prefix cache) - XAttention expects contiguous K, V
# 2. BSA kernel not available
# 3. xattn_estimate not available
if block_tables is not None or not BSA_AVAILABLE or not XATTN_AVAILABLE:
from flash_attn import flash_attn_varlen_func
return flash_attn_varlen_func(
q, k, v,
@@ -304,32 +306,6 @@ class XAttentionBSAPolicy(SparsePolicy):
block_table=block_tables,
)
if not BSA_AVAILABLE:
# Fallback to flash attention if BSA not available
from flash_attn import flash_attn_varlen_func
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=True,
)
if not XATTN_AVAILABLE:
# Fallback to flash attention if xattn not available
from flash_attn import flash_attn_varlen_func
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=True,
)
from nanovllm.ops.xattn import xattn_estimate
# Set DensityObserver mode on first layer
@@ -477,8 +453,7 @@ class XAttentionBSAPolicy(SparsePolicy):
causal_total = q_bk * (q_bk + 1) // 2 * mask_trimmed.shape[0] * mask_trimmed.shape[1]
causal_mask = torch.tril(torch.ones(q_bk, k_bk, device=mask_trimmed.device, dtype=torch.bool))
selected = (mask_trimmed & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
logger.info(f"[DEBUG GPU-only Layer0] mask_shape={mask_trimmed.shape}, "
f"density={selected/causal_total:.6f}, selected={selected}, total={causal_total}")
DensityObserver.record(layer_id, mask_trimmed, causal=True)
return output
@@ -633,15 +608,21 @@ class XAttentionBSAPolicy(SparsePolicy):
l_chunks = []
num_kv_chunks = num_historical_blocks + 1 # +1 for current chunk
# Get compute_stream for all compute kernels (like attention computation)
compute_stream = offload_engine.compute_stream
with nvtx.range("xattn_estimate_pass1"):
slot = 0
# Process historical blocks (from CPU)
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
# Load K from CPU
# Load K from CPU (on slot_transfer_stream)
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
# wait_slot_layer makes compute_stream wait for H2D transfer
offload_engine.wait_slot_layer(slot)
# All compute kernels run on compute_stream (like attention computation)
with torch.cuda.stream(compute_stream):
k_block = offload_engine.get_k_for_slot(slot) # [1, block_size, num_kv_heads, head_dim]
K_chunk = k_block.transpose(1, 2) # [1, num_kv_heads, block_size, head_dim]
@@ -678,7 +659,8 @@ class XAttentionBSAPolicy(SparsePolicy):
offload_engine.record_slot_compute_done(slot)
del attn_weights_kv
# Process current chunk K (already on GPU)
# Process current chunk K (already on GPU) on compute_stream
with torch.cuda.stream(compute_stream):
# k: [seq_len, num_kv_heads, head_dim] -> [1, num_kv_heads, seq_len, head_dim]
K_current = k.unsqueeze(0).transpose(1, 2)
@@ -717,13 +699,16 @@ class XAttentionBSAPolicy(SparsePolicy):
)
m_chunks.append(m_partial_curr)
l_chunks.append(l_partial_curr)
del attn_weights_curr
# ================================================================
# Step 2: Merge all partial stats
# Step 2: Merge all partial stats (on compute_stream)
# ================================================================
with torch.cuda.stream(compute_stream):
with nvtx.range("xattn_estimate_merge"):
m_global, l_global = merge_softmax_stats(m_chunks, l_chunks)
del m_chunks, l_chunks
# ================================================================
@@ -736,9 +721,13 @@ class XAttentionBSAPolicy(SparsePolicy):
# Process historical blocks again
for kv_chunk_idx, cpu_block_id in enumerate(available_blocks):
# Load K from CPU (on slot_transfer_stream)
offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id)
# wait_slot_layer makes compute_stream wait for H2D transfer
offload_engine.wait_slot_layer(slot)
# All compute kernels run on compute_stream
with torch.cuda.stream(compute_stream):
k_block = offload_engine.get_k_for_slot(slot)
K_chunk = k_block.transpose(1, 2)
@@ -775,7 +764,8 @@ class XAttentionBSAPolicy(SparsePolicy):
offload_engine.record_slot_compute_done(slot)
del attn_weights_kv
# Process current chunk
# Process current chunk on compute_stream
with torch.cuda.stream(compute_stream):
# Recompute attention scores for current chunk
attn_weights_curr = flat_group_gemm_fuse_reshape(
Q, K_current, self.stride,
@@ -800,8 +790,9 @@ class XAttentionBSAPolicy(SparsePolicy):
del attn_weights_curr, K_current
# ================================================================
# Step 4: Concatenate block sums and select blocks
# Step 4: Concatenate block sums and select blocks (on compute_stream)
# ================================================================
with torch.cuda.stream(compute_stream):
attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1)
del attn_sum_per_kv, m_global, l_global
@@ -908,14 +899,6 @@ class XAttentionBSAPolicy(SparsePolicy):
if available_blocks and available_blocks[-1] not in selected_block_ids:
selected_block_ids.append(available_blocks[-1])
# Record communication density
if available_blocks:
DensityObserver.record_comm_density(
layer_id,
selected_cpu_blocks=len(selected_block_ids),
total_cpu_blocks=len(available_blocks),
)
# Update statistics (only for layer 0 to avoid overcounting)
if layer_id == 0 and available_blocks:
self._stats_total_available_blocks += len(available_blocks)