[WIP] need to fix model to normally decode.

This commit is contained in:
Zijie Tian
2026-01-01 05:18:27 +08:00
parent 62b8a63314
commit 74ee6d0895
3 changed files with 317 additions and 123 deletions

View File

@@ -118,6 +118,24 @@ class OffloadEngine:
dtype=dtype, device="cuda"
)
# ========== Per-layer decode buffer ==========
# During decode, all layers share decode_slot (no layer dimension in GPU cache).
# This causes accumulated tokens to be overwritten by each layer.
# Solution: Maintain separate per-layer buffers for decode tokens.
# Shape: [num_layers, block_size, kv_heads, head_dim]
# Memory: num_layers * block_size * kv_heads * head_dim * dtype_size
# e.g., 28 * 1024 * 8 * 128 * 2 = 58.7 MB (acceptable)
self.decode_k_buffer = torch.zeros(
num_layers, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.decode_v_buffer = torch.zeros(
num_layers, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")
# ========== Fixed-address CPU KV cache (pinned memory) ==========
self.k_cache_cpu = torch.zeros(
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,

View File

@@ -87,6 +87,15 @@ class Attention(nn.Module):
else: # decode
if context.is_chunked_prefill:
# Chunked decode: need to load all KV from CPU+GPU
# Store current decode token to per-layer decode buffer
# This is needed because GPU cache has no layer dimension,
# so all layers would overwrite each other in decode_slot.
kvcache_manager = context.kvcache_manager
offload_engine = kvcache_manager.offload_engine
pos_in_block = context.decode_pos_in_block
# k, v shape: [1, kv_heads, head_dim]
offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
o = self._chunked_decode_attention(q, k, v, context)
else:
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
@@ -390,25 +399,17 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute decode attention with double-buffering using decode_load_slots.
Compute decode attention using ring buffer pipeline (same as prefill).
Decode uses:
- decode_slot (slot[0]): writes new token's KV
- decode_load_slots (slots[1:]): load previous chunks from CPU
Uses the same loading mechanism as _chunked_prefill_attention:
- Load one block at a time from CPU to GPU slot
- Compute attention for each block
- Merge results using online softmax
- Finally merge with decode buffer (accumulated decode tokens)
Pipeline design:
- First half of decode_load_slots: 'compute' buffer
- Second half: 'prefetch' buffer
- Double-buffer between them for async overlap
Timeline:
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│Load C0→buf0 │ │Load C1→buf1 │ │Load C2→buf0 │ ...
└─────────────┘ └─────────────┘ └─────────────┘
↘ ↘ ↘
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│ Attn(C0) │ │ Attn(C1) │ │ Attn(C2) │
└─────────────┘ └─────────────┘ └─────────────┘
This approach is simpler and proven correct (prefill tests pass).
The only difference from prefill is the additional decode buffer
that stores new tokens generated during decode.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@@ -419,7 +420,6 @@ class Attention(nn.Module):
seq = context.chunked_seq
# Get only PREFILLED CPU blocks (exclude the current decode block)
# The decode block's KV is still in GPU decode_slot, not yet offloaded to CPU
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
if self.layer_id == 0:
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
@@ -427,12 +427,12 @@ class Attention(nn.Module):
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
# Calculate valid tokens in the last block
# prefill_len = total prefilled tokens (current decode token not yet in CPU)
# Note: For chunked prefill, each block is exactly block_size tokens
# The cpu_block_table only contains full prefill blocks
block_size = kvcache_manager.block_size
prefill_len = len(seq) - 1 # Exclude current decode token
last_block_valid_tokens = prefill_len % block_size
if last_block_valid_tokens == 0 and prefill_len > 0:
last_block_valid_tokens = block_size # Last block is full
num_prefill_blocks = len(cpu_block_table)
# All prefill blocks are full (block_size tokens each)
last_block_valid_tokens = block_size
# Apply sparse policy if enabled
if kvcache_manager.sparse_policy is not None:
@@ -440,7 +440,7 @@ class Attention(nn.Module):
query_chunk_idx=0,
num_query_chunks=1,
layer_id=self.layer_id,
query=q_batched, # Decode provides query for query-aware selection
query=q_batched,
is_prefill=False,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
@@ -450,104 +450,28 @@ class Attention(nn.Module):
)
offload_engine = kvcache_manager.offload_engine
compute_stream = offload_engine.compute_stream
load_slots = offload_engine.decode_load_slots # Available slots for loading
# Chunk size = capacity of each double buffer region (compute/prefetch)
# Each region uses half of decode_load_slots
chunk_size = max(1, len(offload_engine.decode_load_slots) // 2)
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
# Use ring buffer pipeline (same as prefill) to load prefilled blocks
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
q_batched, cpu_block_table, load_slots, offload_engine,
block_size, last_block_valid_tokens
)
# Check if double buffering is possible (need at least 2 separate regions)
# With only 1 load slot, compute and prefetch regions overlap -> can't double buffer
can_double_buffer = len(offload_engine.decode_load_slots) >= 2
o_acc = None
lse_acc = None
# Double buffering state: True = use Compute region, False = use Prefetch region
use_compute = True
# Pre-load first chunk to Compute region (async)
first_chunk_ids = cpu_block_table[:min(chunk_size, len(cpu_block_table))]
offload_engine.load_to_compute_layer(self.layer_id, first_chunk_ids)
for chunk_idx in range(num_chunks):
start = chunk_idx * chunk_size
end = min(start + chunk_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
# Wait for current buffer to be ready on compute_stream
# The load runs on transfer_stream_main, compute runs on compute_stream
compute_stream.wait_stream(offload_engine.transfer_stream_main)
# All computation on explicit compute_stream
with torch.cuda.stream(compute_stream):
# Get KV from current buffer FIRST, before prefetching overwrites it
if use_compute:
k_chunk, v_chunk = offload_engine.get_kv_for_compute(num_blocks_in_chunk)
else:
k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(num_blocks_in_chunk)
# Handle partial last block: slice to only include valid tokens
# This is critical because the rest of the block contains stale data
is_last_chunk = (end == len(cpu_block_table))
if is_last_chunk and last_block_valid_tokens < block_size:
# Calculate total valid tokens in this chunk
# All blocks except the last are full, last block has last_block_valid_tokens
full_blocks = num_blocks_in_chunk - 1
valid_tokens = full_blocks * block_size + last_block_valid_tokens
# Slice KV: [batch, seqlen, heads, dim] -> [batch, valid_tokens, heads, dim]
k_chunk = k_chunk[:, :valid_tokens, :, :]
v_chunk = v_chunk[:, :valid_tokens, :, :]
# Compute attention for this chunk
o_chunk, lse_chunk = flash_attn_with_lse(
q_batched, k_chunk, v_chunk,
softmax_scale=self.scale,
causal=False,
)
# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = o_chunk, lse_chunk
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Trigger async prefetch/load of next chunk to the OTHER buffer
# This happens AFTER attention completes, so the data is no longer needed
if chunk_idx + 1 < num_chunks:
next_start = end
next_end = min(next_start + chunk_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if can_double_buffer:
if use_compute:
# Current in Compute, prefetch next to Prefetch region
offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
else:
# Current in Prefetch, prefetch next to Compute region
offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
else:
# Sync fallback: load next chunk to same slot (always compute region)
offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
# Swap buffers for next iteration (only matters if can_double_buffer)
use_compute = not use_compute
# Now attend to Decode region (contains accumulated decode tokens)
# Now attend to accumulated decode tokens from per-layer decode buffer
pos_in_block = context.decode_pos_in_block
start_pos = context.decode_start_pos_in_block
num_accumulated = pos_in_block - start_pos + 1
# IMPORTANT: Sync compute_stream with default stream before reading decode_slot
# store_kvcache writes to decode_slot on default stream (before entering this function)
# We need to ensure that write is complete before reading on compute_stream
# Sync compute_stream with default stream before reading decode_buffer
compute_stream = offload_engine.compute_stream
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
if num_accumulated > 0:
# GPU cache has no layer dimension
decode_k = offload_engine.k_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_v = offload_engine.v_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
# Read from per-layer decode buffer
decode_k = offload_engine.decode_k_buffer[self.layer_id, start_pos:pos_in_block+1]
decode_v = offload_engine.decode_v_buffer[self.layer_id, start_pos:pos_in_block+1]
decode_k = decode_k.unsqueeze(0)
decode_v = decode_v.unsqueeze(0)
@@ -566,7 +490,82 @@ class Attention(nn.Module):
raise RuntimeError("Chunked decode attention failed: no KV available")
# Sync back to default stream before returning
# Caller expects result to be ready on default stream
torch.cuda.default_stream().wait_stream(compute_stream)
return o_acc
def _decode_ring_buffer_pipeline(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
load_slots: list,
offload_engine,
block_size: int,
last_block_valid_tokens: int,
):
"""
Ring buffer pipeline for decode prefill loading (same mechanism as prefill).
Loads one block at a time, computes attention, and merges results.
Uses the same load_to_slot_layer / wait_slot_layer / get_kv_for_slot
methods as prefill for proven correctness.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
num_blocks = len(cpu_block_table)
if num_blocks == 0:
return None, None
if not load_slots:
return None, None
o_acc, lse_acc = None, None
num_slots = len(load_slots)
compute_stream = offload_engine.compute_stream
# Phase 1: Pre-load up to num_slots blocks
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
# Phase 2: Process blocks with pipeline
for block_idx in range(num_blocks):
current_slot = load_slots[block_idx % num_slots]
cpu_block_id = cpu_block_table[block_idx]
# Wait for current slot's transfer to complete
offload_engine.wait_slot_layer(current_slot)
with torch.cuda.stream(compute_stream):
# Get KV from slot
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
# Handle partial last block
is_last_block = (block_idx == num_blocks - 1)
if is_last_block and last_block_valid_tokens < block_size:
prev_k = prev_k[:, :last_block_valid_tokens, :, :]
prev_v = prev_v[:, :last_block_valid_tokens, :, :]
# Compute attention
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
# Record compute done for slot reuse
offload_engine.record_slot_compute_done(current_slot)
# Start loading next block (pipeline)
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
# Merge with accumulated
with torch.cuda.stream(compute_stream):
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc