[WIP] fixing attention compute error.

Zijie Tian
2025-12-30 00:31:48 +08:00
parent bf4c63c7ec
commit 89f8020d38
12 changed files with 2175 additions and 103 deletions


@@ -169,9 +169,11 @@ class Attention(nn.Module):
             else:
                 # Use ring buffer pipeline
                 o_acc, lse_acc = self._ring_buffer_pipeline_load(
-                    q_batched, cpu_block_table, load_slots, offload_engine
+                    q_batched, cpu_block_table, load_slots, offload_engine,
+                    current_chunk_idx
                 )

             # Compute attention against current chunk's KV (with causal mask)
             torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
             current_o, current_lse = flash_attn_with_lse(
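
For reference, the structure of the chunked-prefill path touched here: the current chunk's queries attend to every earlier chunk without a causal mask (those tokens are entirely in the past) and to the current chunk with a causal mask, and the partial results are merged. A minimal sketch under that reading, with `attn` and `merge` standing in for flash_attn_with_lse and merge_attention_outputs (illustrative helpers, not the repo's API):

    def chunked_prefill_attention(q, past_kv_chunks, cur_k, cur_v, attn, merge):
        # Hedged sketch: earlier chunks are strictly in the past of q, so no
        # causal mask is needed there; only the current chunk is causal.
        acc = None
        for k, v in past_kv_chunks:
            part = attn(q, k, v, causal=False)
            acc = part if acc is None else merge(acc, part)
        cur = attn(q, cur_k, cur_v, causal=True)
        return cur if acc is None else merge(acc, cur)
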
@@ -187,11 +189,18 @@ class Attention(nn.Module):
             if o_acc is None:
                 final_o = current_o
             else:
+                # IMPORTANT: o_acc was computed on compute_stream. We need to sync before
+                # reading it on the default stream for the merge operation.
+                if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
+                    offload_engine = kvcache_manager.offload_engine
+                    torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)
+
                 torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
                 final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
                 torch.cuda.nvtx.range_pop()

         torch.cuda.nvtx.range_pop()  # ChunkedPrefill

         # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
         return final_o.squeeze(0)
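
The merge itself is the standard log-sum-exp combination of two partial softmax-attention results, which is why o_acc produced on compute_stream has to be visible to the default stream before it is read. A minimal sketch of that combination, assuming the flash-attn convention of `o` shaped [batch, seqlen_q, heads, head_dim] and `lse` shaped [batch, heads, seqlen_q] (the actual merge_attention_outputs signature is assumed, not quoted):

    import torch

    def merge_partials(o1, lse1, o2, lse2):
        # Combined normalizer: log(exp(lse1) + exp(lse2))
        lse = torch.logaddexp(lse1, lse2)
        # Per-query weights, broadcast to [batch, seqlen_q, heads, 1]
        w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)
        w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
        return o1 * w1 + o2 * w2, lse
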
@@ -205,24 +214,27 @@ class Attention(nn.Module):
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

         o_acc, lse_acc = None, None
+        compute_stream = offload_engine.compute_stream

         for block_idx, cpu_block_id in enumerate(cpu_block_table):
             # Load to slot 0 (single slot)
             offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
             offload_engine.wait_slot_layer(0, self.layer_id)

-            prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
+            # IMPORTANT: Must use compute_stream to match wait_slot_layer
+            with torch.cuda.stream(compute_stream):
+                prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)

-            prev_o, prev_lse = flash_attn_with_lse(
-                q_batched, prev_k, prev_v,
-                softmax_scale=self.scale,
-                causal=False,
-            )
+                prev_o, prev_lse = flash_attn_with_lse(
+                    q_batched, prev_k, prev_v,
+                    softmax_scale=self.scale,
+                    causal=False,
+                )

-            if o_acc is None:
-                o_acc, lse_acc = prev_o, prev_lse
-            else:
-                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+                if o_acc is None:
+                    o_acc, lse_acc = prev_o, prev_lse
+                else:
+                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

         return o_acc, lse_acc
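
The rule this hunk enforces: a stream wait only orders work that is later enqueued on the stream that waited, so the slot reads and the attention kernel must be issued on compute_stream, the same stream the wait was established on. A self-contained sketch of that ordering with plain torch.cuda streams and events (names are illustrative, not the offload engine's API):

    import torch

    transfer_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()
    load_done = torch.cuda.Event()

    def load_block(dst_k, dst_v, src_k_cpu, src_v_cpu):
        # H2D copies go on the transfer stream and are marked with an event.
        with torch.cuda.stream(transfer_stream):
            dst_k.copy_(src_k_cpu, non_blocking=True)
            dst_v.copy_(src_v_cpu, non_blocking=True)
            load_done.record(transfer_stream)

    def attend_block(q, slot_k, slot_v, attn_fn):
        # The wait orders compute_stream after the copy, so the kernels that
        # read the slot must also be enqueued on compute_stream.
        compute_stream.wait_event(load_done)
        with torch.cuda.stream(compute_stream):
            return attn_fn(q, slot_k, slot_v)
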
@@ -232,6 +244,7 @@ class Attention(nn.Module):
         cpu_block_table: list,
         load_slots: list,
         offload_engine,
+        current_chunk_idx: int = -1,
     ):
         """
         Ring buffer async pipeline loading with double buffering.
@@ -269,22 +282,26 @@ class Attention(nn.Module):
         if pipeline_depth == 1:
             # Only 1 slot available, cannot pipeline - use synchronous mode
+            # IMPORTANT: Must use compute_stream to match synchronization in
+            # load_to_slot_layer (waits for compute_done) and wait_slot_layer
             slot = load_slots[0]
+            compute_stream = offload_engine.compute_stream
             for block_idx in range(num_blocks):
                 offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx])
                 offload_engine.wait_slot_layer(slot, self.layer_id)
-                prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
-                prev_o, prev_lse = flash_attn_with_lse(
-                    q_batched, prev_k, prev_v,
-                    softmax_scale=self.scale,
-                    causal=False,
-                )
-                # Record compute done so next load can safely reuse this slot
-                offload_engine.record_slot_compute_done(slot, self.layer_id)
-                if o_acc is None:
-                    o_acc, lse_acc = prev_o, prev_lse
-                else:
-                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+                with torch.cuda.stream(compute_stream):
+                    prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
+                    prev_o, prev_lse = flash_attn_with_lse(
+                        q_batched, prev_k, prev_v,
+                        softmax_scale=self.scale,
+                        causal=False,
+                    )
+                    # Record compute done so next load can safely reuse this slot
+                    offload_engine.record_slot_compute_done(slot, self.layer_id)
+                    if o_acc is None:
+                        o_acc, lse_acc = prev_o, prev_lse
+                    else:
+                        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
             return o_acc, lse_acc

         # N-way pipeline: use ALL available slots for maximum overlap
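
record_slot_compute_done exists so the next host-to-device load cannot overwrite a slot that an attention kernel is still reading. A sketch of that single-slot reuse cycle with plain streams and events (illustrative names; one K and one V buffer per slot assumed):

    import torch

    transfer_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()
    compute_done = torch.cuda.Event()

    def single_slot_loop(slot_k, slot_v, cpu_blocks, q, attn_fn, merge_fn):
        acc = None
        for k_cpu, v_cpu in cpu_blocks:
            with torch.cuda.stream(transfer_stream):
                # Don't clobber the slot before the previous read finished
                # (waiting on a never-recorded event is a no-op on the first pass).
                transfer_stream.wait_event(compute_done)
                slot_k.copy_(k_cpu, non_blocking=True)
                slot_v.copy_(v_cpu, non_blocking=True)
            # Order compute after the copy, then read on that same stream.
            compute_stream.wait_stream(transfer_stream)
            with torch.cuda.stream(compute_stream):
                part = attn_fn(q, slot_k, slot_v)
                compute_done.record(compute_stream)  # slot reusable after this
                acc = part if acc is None else merge_fn(acc, part)
        return acc
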
@@ -378,12 +395,13 @@ class Attention(nn.Module):
         kvcache_manager = context.kvcache_manager
         seq = context.chunked_seq

-        # Get all CPU blocks for this sequence
-        cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
+        # Get only PREFILLED CPU blocks (exclude the current decode block)
+        # The decode block's KV is still in GPU decode_slot, not yet offloaded to CPU
+        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

         if self.layer_id == 0:
             logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")

         if not cpu_block_table:
-            raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
+            raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")

         # Apply sparse policy if enabled
         if kvcache_manager.sparse_policy is not None:
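
In words, the selection rule here: a CPU block qualifies only once its KV has actually been offloaded; the block currently receiving decode tokens still lives in the GPU decode_slot and is attended separately further down. A hypothetical sketch of such a filter (field names invented for illustration):

    def prefilled_cpu_blocks(blocks):
        # Hypothetical fields: keep only blocks whose KV already resides on CPU.
        return [b.cpu_block_id for b in blocks if b.offloaded_to_cpu]
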
@@ -401,12 +419,17 @@ class Attention(nn.Module):
         )

         offload_engine = kvcache_manager.offload_engine
+        compute_stream = offload_engine.compute_stream

         # Chunk size = capacity of each double buffer region (compute/prefetch)
         # Each region uses half of decode_load_slots
         chunk_size = max(1, len(offload_engine.decode_load_slots) // 2)
         num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size

+        # Check if double buffering is possible (need at least 2 separate regions)
+        # With only 1 load slot, compute and prefetch regions overlap -> can't double buffer
+        can_double_buffer = len(offload_engine.decode_load_slots) >= 2
+
         o_acc = None
         lse_acc = None
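
A quick worked example of the sizing above, under assumed numbers (8 decode load slots, 13 prefilled CPU blocks):

    decode_load_slots = list(range(8))   # assumed
    cpu_block_table = list(range(13))    # assumed

    chunk_size = max(1, len(decode_load_slots) // 2)                    # 4 blocks per region
    num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size  # ceil(13 / 4) = 4
    can_double_buffer = len(decode_load_slots) >= 2                     # True
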
@@ -422,49 +445,53 @@ class Attention(nn.Module):
             end = min(start + chunk_size, len(cpu_block_table))
             num_blocks_in_chunk = end - start

-            # Wait for current buffer to be ready
-            if use_compute:
-                offload_engine.wait_compute_layer(self.layer_id)
-            else:
-                offload_engine.wait_prefetch_layer(self.layer_id)
+            # Wait for current buffer to be ready on compute_stream
+            # The load runs on transfer_stream_main, compute runs on compute_stream
+            compute_stream.wait_stream(offload_engine.transfer_stream_main)

-            # Trigger async prefetch of next chunk to the OTHER buffer
-            # This overlaps transfer with current chunk's computation
+            # All computation on explicit compute_stream
+            with torch.cuda.stream(compute_stream):
+                # Get KV from current buffer FIRST, before prefetching overwrites it
+                if use_compute:
+                    k_chunk, v_chunk = offload_engine.get_kv_for_compute(
+                        self.layer_id, num_blocks_in_chunk
+                    )
+                else:
+                    k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
+                        self.layer_id, num_blocks_in_chunk
+                    )
+
+                # Compute attention for this chunk
+                o_chunk, lse_chunk = flash_attn_with_lse(
+                    q_batched, k_chunk, v_chunk,
+                    softmax_scale=self.scale,
+                    causal=False,
+                )
+
+                # Merge with accumulated
+                if o_acc is None:
+                    o_acc, lse_acc = o_chunk, lse_chunk
+                else:
+                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
+
+            # Trigger async prefetch/load of next chunk to the OTHER buffer
+            # This happens AFTER attention completes, so the data is no longer needed
             if chunk_idx + 1 < num_chunks:
                 next_start = end
                 next_end = min(next_start + chunk_size, len(cpu_block_table))
                 next_chunk_ids = cpu_block_table[next_start:next_end]
-                if use_compute:
-                    # Current in Compute, prefetch next to Prefetch region
-                    offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
+                if can_double_buffer:
+                    if use_compute:
+                        # Current in Compute, prefetch next to Prefetch region
+                        offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
+                    else:
+                        # Current in Prefetch, prefetch next to Compute region
+                        offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
                 else:
-                    # Current in Prefetch, prefetch next to Compute region
+                    # Sync fallback: load next chunk to same slot (always compute region)
                     offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)

-            # Get KV from current buffer
-            if use_compute:
-                k_chunk, v_chunk = offload_engine.get_kv_for_compute(
-                    self.layer_id, num_blocks_in_chunk
-                )
-            else:
-                k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
-                    self.layer_id, num_blocks_in_chunk
-                )
-            # Compute attention for this chunk
-            o_chunk, lse_chunk = flash_attn_with_lse(
-                q_batched, k_chunk, v_chunk,
-                softmax_scale=self.scale,
-                causal=False,
-            )
-            # Merge with accumulated
-            if o_acc is None:
-                o_acc, lse_acc = o_chunk, lse_chunk
-            else:
-                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
-            # Swap buffers for next iteration
+            # Swap buffers for next iteration (only matters if can_double_buffer)
             use_compute = not use_compute

         # Now attend to Decode region (contains accumulated decode tokens)
@@ -472,24 +499,29 @@ class Attention(nn.Module):
         start_pos = context.decode_start_pos_in_block
         num_accumulated = pos_in_block - start_pos + 1

-        if num_accumulated > 0:
-            decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
-            decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
-            decode_k = decode_k.unsqueeze(0)
-            decode_v = decode_v.unsqueeze(0)
+        with torch.cuda.stream(compute_stream):
+            if num_accumulated > 0:
+                decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
+                decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
+                decode_k = decode_k.unsqueeze(0)
+                decode_v = decode_v.unsqueeze(0)

-            decode_o, decode_lse = flash_attn_with_lse(
-                q_batched, decode_k, decode_v,
-                softmax_scale=self.scale,
-                causal=False,
-            )
+                decode_o, decode_lse = flash_attn_with_lse(
+                    q_batched, decode_k, decode_v,
+                    softmax_scale=self.scale,
+                    causal=False,
+                )

-            if o_acc is None:
-                o_acc = decode_o
-            else:
-                o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
+                if o_acc is None:
+                    o_acc = decode_o
+                else:
+                    o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)

         if o_acc is None:
             raise RuntimeError("Chunked decode attention failed: no KV available")

+        # Sync back to default stream before returning
+        # Caller expects result to be ready on default stream
+        torch.cuda.default_stream().wait_stream(compute_stream)
+
         return o_acc
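
Abstracted, the decode schedule above is: enqueue attention over the current region on compute_stream, enqueue the next chunk's load into the idle region (it overlaps with the still-running attention because both are asynchronous), swap regions, and finally make the default stream wait on compute_stream so the caller sees a finished result. A schematic of that ping-pong with illustrative helpers (not the offload engine's real API):

    def double_buffered_decode(chunks, enqueue_load, enqueue_attention, merge):
        # Hedged schematic: enqueue_* stand in for the engine's async load and
        # attention calls and return immediately after enqueuing GPU work.
        if not chunks:
            return None
        acc = None
        use_compute_region = True
        enqueue_load("compute", chunks[0])              # prime the first region
        for i, chunk in enumerate(chunks):
            region = "compute" if use_compute_region else "prefetch"
            out = enqueue_attention(region, chunk)      # reads region on compute stream
            if i + 1 < len(chunks):
                other = "prefetch" if use_compute_region else "compute"
                enqueue_load(other, chunks[i + 1])      # overlaps with the attention above
            acc = out if acc is None else merge(acc, out)
            use_compute_region = not use_compute_region
        return acc
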