From 829b311c028e5513c238bcef2f6679e96a9b7ca6 Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Thu, 5 Feb 2026 01:30:23 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20stream=20synchronization?=
 =?UTF-8?q?=20for=20XAttention=20estimate=20kernels=20in=20offload=20mode?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Wrap all compute kernels in select_blocks with compute_stream context
  (Pass 1 historical blocks, Pass 1 current chunk, Step 2 merge,
  Pass 2 historical blocks, Pass 2 current chunk, Step 4 block selection)
- Fix K data mismatch between Pass 1 and Pass 2 by ensuring wait_slot_layer
  syncs with compute_stream where kernels actually run
- Remove STRONG SYNC code from offload_engine.py (now handled by events)
- Remove debug print statements and torch.save code
- Consolidate fallback conditions in compute_with_xattn
- Change default chunk_size from 16384 to 4096 for density alignment

The bug caused Pass 1 and Pass 2 to see different K data from the same
CPU block because compute kernels ran on the default stream while
wait_slot_layer only synced compute_stream.
Co-Authored-By: Claude Opus 4.5 --- nanovllm/kvcache/sparse/xattn_bsa.py | 323 +++++++++++++-------------- 1 file changed, 153 insertions(+), 170 deletions(-) diff --git a/nanovllm/kvcache/sparse/xattn_bsa.py b/nanovllm/kvcache/sparse/xattn_bsa.py index a0c098b..f94d27e 100644 --- a/nanovllm/kvcache/sparse/xattn_bsa.py +++ b/nanovllm/kvcache/sparse/xattn_bsa.py @@ -96,7 +96,7 @@ class XAttentionBSAPolicy(SparsePolicy): self, threshold: float = 0.95, # High threshold for accuracy testing stride: int = 8, - chunk_size: int = 16384, + chunk_size: int = 4096, # Match offload Q chunk size for density alignment block_size: int = 128, samples_per_chunk: int = 128, use_triton: bool = True, @@ -289,9 +289,11 @@ class XAttentionBSAPolicy(SparsePolicy): Returns: Attention output [total_q, num_heads, head_dim] """ - # When block_tables is provided (paged KV cache / prefix cache), - # fallback to flash_attn as XAttention expects contiguous K, V - if block_tables is not None: + # Fallback to flash attention when: + # 1. block_tables provided (paged KV cache / prefix cache) - XAttention expects contiguous K, V + # 2. BSA kernel not available + # 3. 
xattn_estimate not available + if block_tables is not None or not BSA_AVAILABLE or not XATTN_AVAILABLE: from flash_attn import flash_attn_varlen_func return flash_attn_varlen_func( q, k, v, @@ -304,32 +306,6 @@ class XAttentionBSAPolicy(SparsePolicy): block_table=block_tables, ) - if not BSA_AVAILABLE: - # Fallback to flash attention if BSA not available - from flash_attn import flash_attn_varlen_func - return flash_attn_varlen_func( - q, k, v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - softmax_scale=softmax_scale, - causal=True, - ) - - if not XATTN_AVAILABLE: - # Fallback to flash attention if xattn not available - from flash_attn import flash_attn_varlen_func - return flash_attn_varlen_func( - q, k, v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - softmax_scale=softmax_scale, - causal=True, - ) - from nanovllm.ops.xattn import xattn_estimate # Set DensityObserver mode on first layer @@ -477,8 +453,7 @@ class XAttentionBSAPolicy(SparsePolicy): causal_total = q_bk * (q_bk + 1) // 2 * mask_trimmed.shape[0] * mask_trimmed.shape[1] causal_mask = torch.tril(torch.ones(q_bk, k_bk, device=mask_trimmed.device, dtype=torch.bool)) selected = (mask_trimmed & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item() - logger.info(f"[DEBUG GPU-only Layer0] mask_shape={mask_trimmed.shape}, " - f"density={selected/causal_total:.6f}, selected={selected}, total={causal_total}") + DensityObserver.record(layer_id, mask_trimmed, causal=True) return output @@ -633,98 +608,108 @@ class XAttentionBSAPolicy(SparsePolicy): l_chunks = [] num_kv_chunks = num_historical_blocks + 1 # +1 for current chunk + # Get compute_stream for all compute kernels (like attention computation) + compute_stream = offload_engine.compute_stream + with nvtx.range("xattn_estimate_pass1"): slot = 0 # Process historical blocks (from CPU) for kv_chunk_idx, cpu_block_id in 
enumerate(available_blocks): - # Load K from CPU + # Load K from CPU (on slot_transfer_stream) offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) + # wait_slot_layer makes compute_stream wait for H2D transfer offload_engine.wait_slot_layer(slot) - k_block = offload_engine.get_k_for_slot(slot) # [1, block_size, num_kv_heads, head_dim] - K_chunk = k_block.transpose(1, 2) # [1, num_kv_heads, block_size, head_dim] + # All compute kernels run on compute_stream (like attention computation) + with torch.cuda.stream(compute_stream): + k_block = offload_engine.get_k_for_slot(slot) # [1, block_size, num_kv_heads, head_dim] + K_chunk = k_block.transpose(1, 2) # [1, num_kv_heads, block_size, head_dim] - # GQA expansion - num_kv_heads = K_chunk.shape[1] + # GQA expansion + num_kv_heads = K_chunk.shape[1] + if num_heads != num_kv_heads: + num_groups = num_heads // num_kv_heads + K_chunk = K_chunk.repeat_interleave(num_groups, dim=1) + + # KV offset in reshaped space + kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped + + # Compute raw attention scores + attn_weights_kv = flat_group_gemm_fuse_reshape( + Q, K_chunk, self.stride, + chunk_start=chunk_start, + chunk_end=chunk_end, + is_causal=False, # K 不完整,不能在这里用 causal + ) + + # Compute partial stats (带 causal mask) + m_partial, l_partial = softmax_compute_partial_stats( + attn_weights_kv, + reshaped_block_size, + segment_size, + scale, + chunk_start=chunk_start, + kv_offset=kv_offset_reshaped, + is_causal=True, + ) + m_chunks.append(m_partial) + l_chunks.append(l_partial) + + offload_engine.record_slot_compute_done(slot) + del attn_weights_kv + + # Process current chunk K (already on GPU) on compute_stream + with torch.cuda.stream(compute_stream): + # k: [seq_len, num_kv_heads, head_dim] -> [1, num_kv_heads, seq_len, head_dim] + K_current = k.unsqueeze(0).transpose(1, 2) + + # GQA expansion for current chunk + num_kv_heads = K_current.shape[1] if num_heads != num_kv_heads: num_groups 
= num_heads // num_kv_heads - K_chunk = K_chunk.repeat_interleave(num_groups, dim=1) + K_current = K_current.repeat_interleave(num_groups, dim=1) - # KV offset in reshaped space - kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped + # Pad current K to alignment + curr_k_len = K_current.shape[2] + padded_curr_k_len = ((curr_k_len + alignment - 1) // alignment) * alignment + if padded_curr_k_len != curr_k_len: + K_current = torch.nn.functional.pad(K_current, (0, 0, 0, padded_curr_k_len - curr_k_len), value=0) - # Compute raw attention scores - attn_weights_kv = flat_group_gemm_fuse_reshape( - Q, K_chunk, self.stride, + # KV offset for current chunk + kv_offset_current = num_historical_blocks * kv_chunk_reshaped + + # Compute attention scores for current chunk + attn_weights_curr = flat_group_gemm_fuse_reshape( + Q, K_current, self.stride, chunk_start=chunk_start, chunk_end=chunk_end, - is_causal=False, # K 不完整,不能在这里用 causal + is_causal=False, ) - # Compute partial stats (带 causal mask) - m_partial, l_partial = softmax_compute_partial_stats( - attn_weights_kv, + # Compute partial stats for current chunk + m_partial_curr, l_partial_curr = softmax_compute_partial_stats( + attn_weights_curr, reshaped_block_size, segment_size, scale, chunk_start=chunk_start, - kv_offset=kv_offset_reshaped, + kv_offset=kv_offset_current, is_causal=True, ) - m_chunks.append(m_partial) - l_chunks.append(l_partial) + m_chunks.append(m_partial_curr) + l_chunks.append(l_partial_curr) - offload_engine.record_slot_compute_done(slot) - del attn_weights_kv - - # Process current chunk K (already on GPU) - # k: [seq_len, num_kv_heads, head_dim] -> [1, num_kv_heads, seq_len, head_dim] - K_current = k.unsqueeze(0).transpose(1, 2) - - # GQA expansion for current chunk - num_kv_heads = K_current.shape[1] - if num_heads != num_kv_heads: - num_groups = num_heads // num_kv_heads - K_current = K_current.repeat_interleave(num_groups, dim=1) - - # Pad current K to alignment - curr_k_len = K_current.shape[2] 
- padded_curr_k_len = ((curr_k_len + alignment - 1) // alignment) * alignment - if padded_curr_k_len != curr_k_len: - K_current = torch.nn.functional.pad(K_current, (0, 0, 0, padded_curr_k_len - curr_k_len), value=0) - - # KV offset for current chunk - kv_offset_current = num_historical_blocks * kv_chunk_reshaped - - # Compute attention scores for current chunk - attn_weights_curr = flat_group_gemm_fuse_reshape( - Q, K_current, self.stride, - chunk_start=chunk_start, - chunk_end=chunk_end, - is_causal=False, - ) - - # Compute partial stats for current chunk - m_partial_curr, l_partial_curr = softmax_compute_partial_stats( - attn_weights_curr, - reshaped_block_size, - segment_size, - scale, - chunk_start=chunk_start, - kv_offset=kv_offset_current, - is_causal=True, - ) - m_chunks.append(m_partial_curr) - l_chunks.append(l_partial_curr) - del attn_weights_curr + del attn_weights_curr # ================================================================ - # Step 2: Merge all partial stats + # Step 2: Merge all partial stats (on compute_stream) # ================================================================ - with nvtx.range("xattn_estimate_merge"): - m_global, l_global = merge_softmax_stats(m_chunks, l_chunks) - del m_chunks, l_chunks + with torch.cuda.stream(compute_stream): + with nvtx.range("xattn_estimate_merge"): + m_global, l_global = merge_softmax_stats(m_chunks, l_chunks) + + del m_chunks, l_chunks # ================================================================ # Step 3: Second pass - normalize and compute block sums @@ -736,30 +721,61 @@ class XAttentionBSAPolicy(SparsePolicy): # Process historical blocks again for kv_chunk_idx, cpu_block_id in enumerate(available_blocks): + # Load K from CPU (on slot_transfer_stream) offload_engine.load_k_only_to_slot_layer(slot, layer_id, cpu_block_id, chunk_idx=cpu_block_id) + # wait_slot_layer makes compute_stream wait for H2D transfer offload_engine.wait_slot_layer(slot) - k_block = offload_engine.get_k_for_slot(slot) 
- K_chunk = k_block.transpose(1, 2) + # All compute kernels run on compute_stream + with torch.cuda.stream(compute_stream): + k_block = offload_engine.get_k_for_slot(slot) + K_chunk = k_block.transpose(1, 2) - num_kv_heads = K_chunk.shape[1] - if num_heads != num_kv_heads: - num_groups = num_heads // num_kv_heads - K_chunk = K_chunk.repeat_interleave(num_groups, dim=1) + num_kv_heads = K_chunk.shape[1] + if num_heads != num_kv_heads: + num_groups = num_heads // num_kv_heads + K_chunk = K_chunk.repeat_interleave(num_groups, dim=1) - kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped + kv_offset_reshaped = kv_chunk_idx * kv_chunk_reshaped - # Recompute attention scores (trade-off: compute vs memory) - attn_weights_kv = flat_group_gemm_fuse_reshape( - Q, K_chunk, self.stride, + # Recompute attention scores (trade-off: compute vs memory) + attn_weights_kv = flat_group_gemm_fuse_reshape( + Q, K_chunk, self.stride, + chunk_start=chunk_start, + chunk_end=chunk_end, + is_causal=False, + ) + + # Normalize with global stats and compute block sums + block_sum_kv = softmax_normalize_and_block_sum( + attn_weights_kv, + m_global, + l_global, + reshaped_block_size, + segment_size, + chunk_start=chunk_start, + real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad, + scale=scale, + kv_offset=kv_offset_reshaped, + is_causal=True, + ) + attn_sum_per_kv.append(block_sum_kv) + + offload_engine.record_slot_compute_done(slot) + del attn_weights_kv + + # Process current chunk on compute_stream + with torch.cuda.stream(compute_stream): + # Recompute attention scores for current chunk + attn_weights_curr = flat_group_gemm_fuse_reshape( + Q, K_current, self.stride, chunk_start=chunk_start, chunk_end=chunk_end, is_causal=False, ) - # Normalize with global stats and compute block sums - block_sum_kv = softmax_normalize_and_block_sum( - attn_weights_kv, + block_sum_curr = softmax_normalize_and_block_sum( + attn_weights_curr, m_global, l_global, reshaped_block_size, @@ -767,67 +783,42 @@ 
class XAttentionBSAPolicy(SparsePolicy): chunk_start=chunk_start, real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad, scale=scale, - kv_offset=kv_offset_reshaped, + kv_offset=kv_offset_current, is_causal=True, ) - attn_sum_per_kv.append(block_sum_kv) - - offload_engine.record_slot_compute_done(slot) - del attn_weights_kv - - # Process current chunk - # Recompute attention scores for current chunk - attn_weights_curr = flat_group_gemm_fuse_reshape( - Q, K_current, self.stride, - chunk_start=chunk_start, - chunk_end=chunk_end, - is_causal=False, - ) - - block_sum_curr = softmax_normalize_and_block_sum( - attn_weights_curr, - m_global, - l_global, - reshaped_block_size, - segment_size, - chunk_start=chunk_start, - real_q_len=k_reshaped_seq_len - k_reshaped_num_to_pad, - scale=scale, - kv_offset=kv_offset_current, - is_causal=True, - ) - attn_sum_per_kv.append(block_sum_curr) - del attn_weights_curr, K_current + attn_sum_per_kv.append(block_sum_curr) + del attn_weights_curr, K_current # ================================================================ - # Step 4: Concatenate block sums and select blocks + # Step 4: Concatenate block sums and select blocks (on compute_stream) # ================================================================ - attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1) - del attn_sum_per_kv, m_global, l_global + with torch.cuda.stream(compute_stream): + attn_sum_concat = torch.cat(attn_sum_per_kv, dim=-1) + del attn_sum_per_kv, m_global, l_global - # Calculate q_block offset for find_blocks_chunked - # This is the number of BSA blocks before Q in the full sequence - num_blocks_per_chunk = q_reshaped_len // reshaped_block_size - current_index = k_block_num - q_block_num # Q starts at this BSA block index + # Calculate q_block offset for find_blocks_chunked + # This is the number of BSA blocks before Q in the full sequence + num_blocks_per_chunk = q_reshaped_len // reshaped_block_size + current_index = k_block_num - q_block_num # Q starts at 
this BSA block index - with nvtx.range("xattn_find_blocks"): - mask = find_blocks_chunked( - attn_sum_concat, - current_index=current_index, - threshold=self.threshold, - num_to_choose=None, - decoding=False, - mode="prefill", - causal=True, + with nvtx.range("xattn_find_blocks"): + mask = find_blocks_chunked( + attn_sum_concat, + current_index=current_index, + threshold=self.threshold, + num_to_choose=None, + decoding=False, + mode="prefill", + causal=True, + ) + + # Apply causal mask post-processing (same as xattn.py lines 1300-1306) + mask[:, :, -q_block_num:, -q_block_num:] = torch.where( + torch.tril(torch.ones(q_block_num, q_block_num, dtype=torch.bool, device=mask.device), diagonal=0), + mask[:, :, -q_block_num:, -q_block_num:], + False, ) - # Apply causal mask post-processing (same as xattn.py lines 1300-1306) - mask[:, :, -q_block_num:, -q_block_num:] = torch.where( - torch.tril(torch.ones(q_block_num, q_block_num, dtype=torch.bool, device=mask.device), diagonal=0), - mask[:, :, -q_block_num:, -q_block_num:], - False, - ) - # ================================================================ # Step 5: Record density (only on layer 0) # ================================================================ @@ -908,14 +899,6 @@ class XAttentionBSAPolicy(SparsePolicy): if available_blocks and available_blocks[-1] not in selected_block_ids: selected_block_ids.append(available_blocks[-1]) - # Record communication density - if available_blocks: - DensityObserver.record_comm_density( - layer_id, - selected_cpu_blocks=len(selected_block_ids), - total_cpu_blocks=len(available_blocks), - ) - # Update statistics (only for layer 0 to avoid overcounting) if layer_id == 0 and available_blocks: self._stats_total_available_blocks += len(available_blocks)