[WIP] fixing attention compute error.
@@ -169,9 +169,11 @@ class Attention(nn.Module):
         else:
             # Use ring buffer pipeline
             o_acc, lse_acc = self._ring_buffer_pipeline_load(
-                q_batched, cpu_block_table, load_slots, offload_engine
+                q_batched, cpu_block_table, load_slots, offload_engine,
+                current_chunk_idx
             )

         # Compute attention against current chunk's KV (with causal mask)
         torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
         current_o, current_lse = flash_attn_with_lse(
@@ -187,11 +189,18 @@ class Attention(nn.Module):
         if o_acc is None:
             final_o = current_o
         else:
+            # IMPORTANT: o_acc was computed on compute_stream. We need to sync before
+            # reading it on the default stream for the merge operation.
+            if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
+                offload_engine = kvcache_manager.offload_engine
+                torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)
+
             torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
             final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
             torch.cuda.nvtx.range_pop()

         torch.cuda.nvtx.range_pop()  # ChunkedPrefill

         # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
         return final_o.squeeze(0)
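Note on the merge step above: merge_attention_outputs combines two partial attention results using their log-sum-exp (LSE) statistics, which is why the default stream must first wait for compute_stream before reading o_acc. A minimal self-contained sketch of that merge rule, assuming flash-attn-style shapes (o: [batch, tokens, heads, head_dim], lse: [batch, heads, tokens]); this is illustrative, not nanovllm's actual implementation:

import torch

def merge_partials(o1, lse1, o2, lse2):
    # Recover each partition's softmax mass in a numerically stable way.
    max_lse = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - max_lse)
    w2 = torch.exp(lse2 - max_lse)
    denom = w1 + w2
    # Reshape the per-token weights to broadcast over head_dim.
    w1 = (w1 / denom).transpose(1, 2).unsqueeze(-1)
    w2 = (w2 / denom).transpose(1, 2).unsqueeze(-1)
    merged_o = o1 * w1 + o2 * w2
    merged_lse = max_lse + torch.log(denom)
    return merged_o, merged_lse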
@@ -205,24 +214,27 @@ class Attention(nn.Module):
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

         o_acc, lse_acc = None, None
+        compute_stream = offload_engine.compute_stream

         for block_idx, cpu_block_id in enumerate(cpu_block_table):
             # Load to slot 0 (single slot)
             offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
             offload_engine.wait_slot_layer(0, self.layer_id)

-            prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
+            # IMPORTANT: Must use compute_stream to match wait_slot_layer
+            with torch.cuda.stream(compute_stream):
+                prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)

-            prev_o, prev_lse = flash_attn_with_lse(
-                q_batched, prev_k, prev_v,
-                softmax_scale=self.scale,
-                causal=False,
-            )
+                prev_o, prev_lse = flash_attn_with_lse(
+                    q_batched, prev_k, prev_v,
+                    softmax_scale=self.scale,
+                    causal=False,
+                )

-            if o_acc is None:
-                o_acc, lse_acc = prev_o, prev_lse
-            else:
-                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+                if o_acc is None:
+                    o_acc, lse_acc = prev_o, prev_lse
+                else:
+                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

         return o_acc, lse_acc
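The "must use compute_stream to match wait_slot_layer" comment is the core of this fix: the stream that waits for the transfer must be the stream that actually reads the KV. A hypothetical illustration of that pattern (the helper and arguments below are assumptions, not the OffloadEngine API):

import torch

def load_then_compute(cpu_chunk, weight, transfer_stream, compute_stream):
    # cpu_chunk should be pinned for a truly async H2D copy.
    ready = torch.cuda.Event()
    with torch.cuda.stream(transfer_stream):
        gpu_chunk = cpu_chunk.to("cuda", non_blocking=True)
        ready.record(transfer_stream)
    # Wait on the SAME stream that will read gpu_chunk; waiting on a different
    # stream (e.g. the default stream) would not order this kernel after the copy.
    compute_stream.wait_event(ready)
    with torch.cuda.stream(compute_stream):
        return gpu_chunk @ weight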
@@ -232,6 +244,7 @@ class Attention(nn.Module):
         cpu_block_table: list,
         load_slots: list,
         offload_engine,
+        current_chunk_idx: int = -1,
     ):
         """
         Ring buffer async pipeline loading with double buffering.
@@ -269,22 +282,26 @@ class Attention(nn.Module):

         if pipeline_depth == 1:
             # Only 1 slot available, cannot pipeline - use synchronous mode
+            # IMPORTANT: Must use compute_stream to match synchronization in
+            # load_to_slot_layer (waits for compute_done) and wait_slot_layer
             slot = load_slots[0]
+            compute_stream = offload_engine.compute_stream
             for block_idx in range(num_blocks):
                 offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx])
                 offload_engine.wait_slot_layer(slot, self.layer_id)
-                prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
-                prev_o, prev_lse = flash_attn_with_lse(
-                    q_batched, prev_k, prev_v,
-                    softmax_scale=self.scale,
-                    causal=False,
-                )
-                # Record compute done so next load can safely reuse this slot
-                offload_engine.record_slot_compute_done(slot, self.layer_id)
-                if o_acc is None:
-                    o_acc, lse_acc = prev_o, prev_lse
-                else:
-                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
+                with torch.cuda.stream(compute_stream):
+                    prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
+                    prev_o, prev_lse = flash_attn_with_lse(
+                        q_batched, prev_k, prev_v,
+                        softmax_scale=self.scale,
+                        causal=False,
+                    )
+                    # Record compute done so next load can safely reuse this slot
+                    offload_engine.record_slot_compute_done(slot, self.layer_id)
+                    if o_acc is None:
+                        o_acc, lse_acc = prev_o, prev_lse
+                    else:
+                        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
             return o_acc, lse_acc

         # N-way pipeline: use ALL available slots for maximum overlap
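For the N-way path referenced by the comment above, the ring-buffer idea is that slot i % N may only be reloaded once block i - N's compute has been recorded as done, so transfers for future blocks overlap with compute on earlier ones. A rough sketch with per-slot events, under assumed buffers and streams (not the real engine's slot/event bookkeeping):

import torch

def ring_pipeline(cpu_blocks, gpu_slots, transfer_stream, compute_stream, consume):
    n = len(gpu_slots)
    load_done = [torch.cuda.Event() for _ in range(n)]
    compute_done = [torch.cuda.Event() for _ in range(n)]
    for i, blk in enumerate(cpu_blocks):
        s = i % n
        if i >= n:
            # Slot s was last used by block i - n; its compute must finish
            # before the next copy is allowed to overwrite the slot.
            transfer_stream.wait_event(compute_done[s])
        with torch.cuda.stream(transfer_stream):
            gpu_slots[s].copy_(blk, non_blocking=True)
            load_done[s].record(transfer_stream)
        compute_stream.wait_event(load_done[s])
        with torch.cuda.stream(compute_stream):
            consume(gpu_slots[s])
            compute_done[s].record(compute_stream)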
@@ -378,12 +395,13 @@ class Attention(nn.Module):
         kvcache_manager = context.kvcache_manager
         seq = context.chunked_seq

-        # Get all CPU blocks for this sequence
-        cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
+        # Get only PREFILLED CPU blocks (exclude the current decode block)
+        # The decode block's KV is still in GPU decode_slot, not yet offloaded to CPU
+        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
         if self.layer_id == 0:
             logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
         if not cpu_block_table:
-            raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
+            raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")

         # Apply sparse policy if enabled
         if kvcache_manager.sparse_policy is not None:
@@ -401,12 +419,17 @@ class Attention(nn.Module):
         )

         offload_engine = kvcache_manager.offload_engine
+        compute_stream = offload_engine.compute_stream

         # Chunk size = capacity of each double buffer region (compute/prefetch)
         # Each region uses half of decode_load_slots
         chunk_size = max(1, len(offload_engine.decode_load_slots) // 2)
         num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size

+        # Check if double buffering is possible (need at least 2 separate regions)
+        # With only 1 load slot, compute and prefetch regions overlap -> can't double buffer
+        can_double_buffer = len(offload_engine.decode_load_slots) >= 2
+
         o_acc = None
         lse_acc = None

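As a worked example of the chunking arithmetic above (illustrative numbers only, not values from the repo):

# With 5 decode load slots, each double-buffer region holds 5 // 2 = 2 blocks,
# and 7 prefilled CPU blocks split into ceil(7 / 2) = 4 chunks.
decode_load_slots = list(range(5))              # hypothetical slot ids
cpu_block_table = list(range(7))                # hypothetical prefilled CPU block ids

chunk_size = max(1, len(decode_load_slots) // 2)                    # -> 2
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size  # -> 4
can_double_buffer = len(decode_load_slots) >= 2                     # -> True

chunks = [cpu_block_table[i * chunk_size:(i + 1) * chunk_size] for i in range(num_chunks)]
assert chunks == [[0, 1], [2, 3], [4, 5], [6]]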
@@ -422,49 +445,53 @@ class Attention(nn.Module):
             end = min(start + chunk_size, len(cpu_block_table))
             num_blocks_in_chunk = end - start

-            # Wait for current buffer to be ready
-            if use_compute:
-                offload_engine.wait_compute_layer(self.layer_id)
-            else:
-                offload_engine.wait_prefetch_layer(self.layer_id)
+            # Wait for current buffer to be ready on compute_stream
+            # The load runs on transfer_stream_main, compute runs on compute_stream
+            compute_stream.wait_stream(offload_engine.transfer_stream_main)

-            # Trigger async prefetch of next chunk to the OTHER buffer
-            # This overlaps transfer with current chunk's computation
+            # All computation on explicit compute_stream
+            with torch.cuda.stream(compute_stream):
+                # Get KV from current buffer FIRST, before prefetching overwrites it
+                if use_compute:
+                    k_chunk, v_chunk = offload_engine.get_kv_for_compute(
+                        self.layer_id, num_blocks_in_chunk
+                    )
+                else:
+                    k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
+                        self.layer_id, num_blocks_in_chunk
+                    )
+
+                # Compute attention for this chunk
+                o_chunk, lse_chunk = flash_attn_with_lse(
+                    q_batched, k_chunk, v_chunk,
+                    softmax_scale=self.scale,
+                    causal=False,
+                )
+
+                # Merge with accumulated
+                if o_acc is None:
+                    o_acc, lse_acc = o_chunk, lse_chunk
+                else:
+                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
+
+            # Trigger async prefetch/load of next chunk to the OTHER buffer
+            # This happens AFTER attention completes, so the data is no longer needed
             if chunk_idx + 1 < num_chunks:
                 next_start = end
                 next_end = min(next_start + chunk_size, len(cpu_block_table))
                 next_chunk_ids = cpu_block_table[next_start:next_end]
-                if use_compute:
-                    # Current in Compute, prefetch next to Prefetch region
-                    offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
+                if can_double_buffer:
+                    if use_compute:
+                        # Current in Compute, prefetch next to Prefetch region
+                        offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
+                    else:
+                        # Current in Prefetch, prefetch next to Compute region
+                        offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
                 else:
-                    # Current in Prefetch, prefetch next to Compute region
+                    # Sync fallback: load next chunk to same slot (always compute region)
                     offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)

-            # Get KV from current buffer
-            if use_compute:
-                k_chunk, v_chunk = offload_engine.get_kv_for_compute(
-                    self.layer_id, num_blocks_in_chunk
-                )
-            else:
-                k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
-                    self.layer_id, num_blocks_in_chunk
-                )
-
-            # Compute attention for this chunk
-            o_chunk, lse_chunk = flash_attn_with_lse(
-                q_batched, k_chunk, v_chunk,
-                softmax_scale=self.scale,
-                causal=False,
-            )
-
-            # Merge with accumulated
-            if o_acc is None:
-                o_acc, lse_acc = o_chunk, lse_chunk
-            else:
-                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)

-            # Swap buffers for next iteration
+            # Swap buffers for next iteration (only matters if can_double_buffer)
             use_compute = not use_compute

         # Now attend to Decode region (contains accumulated decode tokens)
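The loop above is a ping-pong double buffer: attention runs on the region that was just filled while the next chunk streams into the other region, and the roles swap each iteration. A self-contained sketch of that shape, under assumed buffers, streams, and a consume callback (not the engine's real load_to_*_layer / get_kv_for_* calls):

import torch

def ping_pong_decode(chunks, regions, transfer_stream, compute_stream, consume):
    # regions: two preallocated GPU buffers; chunks: pinned CPU tensors of matching shape.
    compute_done = [torch.cuda.Event() for _ in regions]
    with torch.cuda.stream(transfer_stream):
        regions[0].copy_(chunks[0], non_blocking=True)   # prime the first region
    use = 0
    for i in range(len(chunks)):
        # The current chunk's copy must be visible to the compute stream.
        compute_stream.wait_stream(transfer_stream)
        with torch.cuda.stream(compute_stream):
            consume(regions[use])
            compute_done[use].record(compute_stream)
        if i + 1 < len(chunks):
            nxt = 1 - use
            if i >= 1:
                # The other region was read at iteration i - 1; don't overwrite it earlier.
                transfer_stream.wait_event(compute_done[nxt])
            with torch.cuda.stream(transfer_stream):
                regions[nxt].copy_(chunks[i + 1], non_blocking=True)
        use = 1 - use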
@@ -472,24 +499,29 @@ class Attention(nn.Module):
         start_pos = context.decode_start_pos_in_block
         num_accumulated = pos_in_block - start_pos + 1

-        if num_accumulated > 0:
-            decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
-            decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
-            decode_k = decode_k.unsqueeze(0)
-            decode_v = decode_v.unsqueeze(0)
+        with torch.cuda.stream(compute_stream):
+            if num_accumulated > 0:
+                decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
+                decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
+                decode_k = decode_k.unsqueeze(0)
+                decode_v = decode_v.unsqueeze(0)

-            decode_o, decode_lse = flash_attn_with_lse(
-                q_batched, decode_k, decode_v,
-                softmax_scale=self.scale,
-                causal=False,
-            )
+                decode_o, decode_lse = flash_attn_with_lse(
+                    q_batched, decode_k, decode_v,
+                    softmax_scale=self.scale,
+                    causal=False,
+                )

-            if o_acc is None:
-                o_acc = decode_o
-            else:
-                o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
+                if o_acc is None:
+                    o_acc = decode_o
+                else:
+                    o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)

         if o_acc is None:
             raise RuntimeError("Chunked decode attention failed: no KV available")

+        # Sync back to default stream before returning
+        # Caller expects result to be ready on default stream
+        torch.cuda.default_stream().wait_stream(compute_stream)
+
         return o_acc
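The final default_stream().wait_stream(compute_stream) is the hand-off back to the caller, which consumes o_acc on the default stream. A minimal standalone illustration of that ordering (generic example, not nanovllm code):

import torch

side = torch.cuda.Stream()
x = torch.randn(2048, 2048, device="cuda")

with torch.cuda.stream(side):
    y = x @ x                                     # queued on the side stream

torch.cuda.default_stream().wait_stream(side)     # default stream now ordered after the matmul
print(y.sum().item())                             # safe to consume on the default stream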