[fix] Fixed kvcache offload bugs.

Author: Zijie Tian
Date: 2025-12-10 22:34:00 +08:00
parent 190df5f70d
commit e85c2b4776
7 changed files with 409 additions and 156 deletions


@@ -1,3 +1,4 @@
import logging
import torch
from torch import nn
import triton
@@ -6,6 +7,8 @@ import triton.language as tl
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context
logger = logging.getLogger(__name__)
@triton.jit
def store_kvcache_kernel(
@@ -97,13 +100,16 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute attention with Ping-Pong dual buffer for chunked prefill.
Compute attention with three-region GPU buffer for chunked prefill.
For chunked prefill:
1. Load previous KV from CPU using Ping-Pong (if any previous chunks)
1. Load previous KV from CPU using the Compute/Prefetch regions (if any previous chunks)
2. Compute attention against previous KV chunks (no causal mask)
3. Compute attention against current chunk's KV (causal)
4. Merge all results using online softmax
The three-region design guarantees that the current chunk's KV stays in the Compute region while previous KV is loaded from CPU into the Prefetch region,
so the write region and the load region never overlap.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
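Both chunked paths in this file merge partial attention outputs via their log-sum-exp (LSE) values. As a reference for what merge_attention_outputs computes, here is a minimal sketch of the online-softmax merge; the tensor layout and exact signature of the real helper in nanovllm.kvcache.chunked_attention may differ.

import torch

def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
    # Online-softmax merge: combine two partial attention results computed
    # over disjoint key/value chunks into the result over their union.
    # o*: [..., head_dim], lse*: [...] (log-sum-exp of the attention scores).
    lse = torch.logaddexp(lse1, lse2)            # LSE over the union of keys
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)     # weight of the first partial output
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)     # weight of the second partial output
    return w1 * o1 + w2 * o2, lse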
@@ -116,7 +122,7 @@ class Attention(nn.Module):
o_acc = None
lse_acc = None
# Load previous KV from CPU using Ping-Pong
# Load previous KV from CPU using the Compute/Prefetch regions
# Note: context.offload_engine is actually HybridKVCacheManager
kvcache_manager = context.offload_engine
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
@@ -127,41 +133,42 @@ class Attention(nn.Module):
if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
ping_size = offload_engine.ping_size
num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
current_buffer = "ping"
# Use the Prefetch region to load previous KV (it does not conflict with the current Compute region)
prefetch_size = offload_engine.num_prefetch_blocks
num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
use_compute = True # Alternate between the Compute region and the Prefetch region
# Prefetch first chunk to Ping buffer
first_chunk_end = min(ping_size, len(cpu_block_table))
# First, load the previous KV into the Prefetch region
# Only layer 0 triggers the load (loads ALL layers at once)
first_chunk_end = min(prefetch_size, len(cpu_block_table))
first_chunk_ids = cpu_block_table[:first_chunk_end]
offload_engine.load_to_ping(first_chunk_ids)
if self.layer_id == 0:
offload_engine.load_to_prefetch(first_chunk_ids)
for chunk_idx in range(num_chunks):
start = chunk_idx * ping_size
end = min(start + ping_size, len(cpu_block_table))
start = chunk_idx * prefetch_size
end = min(start + prefetch_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
# Prefetch next chunk to OTHER buffer
if chunk_idx + 1 < num_chunks:
# Prefetch next chunk to other buffer (if exists)
# Only layer 0 triggers the load
if chunk_idx + 1 < num_chunks and self.layer_id == 0:
next_start = end
next_end = min(next_start + ping_size, len(cpu_block_table))
next_end = min(next_start + prefetch_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if current_buffer == "ping":
offload_engine.load_to_pong(next_chunk_ids)
if use_compute:
# Currently using the Prefetch region; the next chunk would go to the Compute region (if there is room)
# Note: the Compute region already holds the current chunk's KV and must not be overwritten!
# So we use a simple synchronous strategy here: wait for the current chunk to finish before loading the next one
pass # Simplified version: no double buffering, only the Prefetch region is used
else:
offload_engine.load_to_ping(next_chunk_ids)
offload_engine.load_to_prefetch(next_chunk_ids)
# Wait for current buffer and get KV
if current_buffer == "ping":
offload_engine.wait_ping()
prev_k, prev_v = offload_engine.get_kv_for_ping_slots(
self.layer_id, num_blocks_in_chunk
)
else:
offload_engine.wait_pong()
prev_k, prev_v = offload_engine.get_kv_for_pong_slots(
self.layer_id, num_blocks_in_chunk
)
# Wait for the Prefetch region and get KV
offload_engine.wait_prefetch()
prev_k, prev_v = offload_engine.get_kv_for_prefetch(
self.layer_id, num_blocks_in_chunk
)
# Compute attention against this chunk (no causal mask)
prev_o, prev_lse = flash_attn_with_lse(
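The loop above walks cpu_block_table in Prefetch-region-sized slices using ceiling division; only layer 0 issues the load because each load transfers all layers of a block at once. A standalone illustration of the slicing arithmetic, with made-up sizes (prefetch_size and the block count here are examples, not values from the engine):

cpu_block_table = list(range(10))   # e.g. 10 CPU-resident KV block ids
prefetch_size = 4                   # e.g. the Prefetch region holds 4 blocks
num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size   # ceil(10 / 4) = 3
for chunk_idx in range(num_chunks):
    start = chunk_idx * prefetch_size
    end = min(start + prefetch_size, len(cpu_block_table))
    print(chunk_idx, cpu_block_table[start:end])   # [0..3], [4..7], [8, 9]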
@@ -178,8 +185,12 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Switch buffer
current_buffer = "pong" if current_buffer == "ping" else "ping"
# Load next chunk to the Prefetch region (if it exists)
if chunk_idx + 1 < num_chunks and self.layer_id == 0:
next_start = end
next_end = min(next_start + prefetch_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
offload_engine.load_to_prefetch(next_chunk_ids)
# Compute attention against current chunk's KV (with causal mask)
current_o, current_lse = flash_attn_with_lse(
@@ -207,13 +218,16 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute decode attention with Ping-Pong dual buffer.
Compute decode attention with three-region GPU buffer.
All KV is stored on CPU. Uses Ping-Pong buffers on GPU:
1. Load first chunk to Ping buffer
2. While computing on current buffer, prefetch next chunk to other buffer
3. Alternate between Ping and Pong buffers
4. Merge attention outputs using online softmax (LSE)
All KV is stored on CPU. Uses the Compute region buffer on GPU:
1. Load chunk to the Compute region
2. Compute attention
3. Repeat for all chunks
4. Finally, attend to the Decode region (slot 0), which contains the new token's KV
5. Merge all attention outputs using online softmax (LSE)
Key point: the new token's KV lives in the Decode region (slot 0) and is never overwritten by loads into the Compute region.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
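The decode path splits attention over the CPU-resident history (streamed through the Compute region) and the Decode region slot, then merges the pieces with LSE. A small self-contained check that this chunk-and-merge scheme reproduces full attention exactly; this uses a naive reference implementation, not the FlashAttention kernels used above.

import torch

def naive_attn_with_lse(q, k, v, scale):
    # q: [q_len, d], k/v: [k_len, d]; returns output and per-query log-sum-exp.
    s = (q @ k.T) * scale
    lse = torch.logsumexp(s, dim=-1)
    return torch.softmax(s, dim=-1) @ v, lse

def merge(o1, lse1, o2, lse2):
    lse = torch.logaddexp(lse1, lse2)
    return (torch.exp(lse1 - lse).unsqueeze(-1) * o1
            + torch.exp(lse2 - lse).unsqueeze(-1) * o2), lse

torch.manual_seed(0)
d, history, scale = 64, 48, 64 ** -0.5
q = torch.randn(1, d)                    # the single decode-step query
k = torch.randn(history + 1, d)          # history plus the new token's key
v = torch.randn(history + 1, d)

full_o, _ = naive_attn_with_lse(q, k, v, scale)                    # attend to everything at once
o1, l1 = naive_attn_with_lse(q, k[:history], v[:history], scale)   # "Compute region" chunk(s)
o2, l2 = naive_attn_with_lse(q, k[history:], v[history:], scale)   # "Decode region" slot
chunked_o, _ = merge(o1, l1, o2, l2)
assert torch.allclose(full_o, chunked_o, atol=1e-5)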
@@ -227,51 +241,37 @@ class Attention(nn.Module):
# Get all CPU blocks for this sequence
cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
if self.layer_id == 0:
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
# Get the actual offload_engine for Ping-Pong operations
# Get the actual offload_engine for three-region operations
offload_engine = kvcache_manager.offload_engine
# Calculate chunk info
ping_size = offload_engine.ping_size
num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
# Calculate chunk info using the Compute region
compute_size = offload_engine.num_compute_blocks
num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size
o_acc = None
lse_acc = None
current_buffer = "ping"
# Prefetch first chunk to Ping buffer (loads all layers at once)
first_chunk_end = min(ping_size, len(cpu_block_table))
first_chunk_ids = cpu_block_table[:first_chunk_end]
offload_engine.load_to_ping(first_chunk_ids)
for chunk_idx in range(num_chunks):
start = chunk_idx * ping_size
end = min(start + ping_size, len(cpu_block_table))
start = chunk_idx * compute_size
end = min(start + compute_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
chunk_ids = cpu_block_table[start:end]
# Prefetch next chunk to OTHER buffer (overlapped with current computation)
if chunk_idx + 1 < num_chunks:
next_start = end
next_end = min(next_start + ping_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if current_buffer == "ping":
offload_engine.load_to_pong(next_chunk_ids)
else:
offload_engine.load_to_ping(next_chunk_ids)
# Load this chunk to the Compute region
# Only layer 0 triggers the load (loads ALL layers at once)
if self.layer_id == 0:
offload_engine.load_to_compute(chunk_ids)
# Wait for current buffer to be ready and get KV
if current_buffer == "ping":
offload_engine.wait_ping()
k_chunk, v_chunk = offload_engine.get_kv_for_ping_slots(
self.layer_id, num_blocks_in_chunk
)
else:
offload_engine.wait_pong()
k_chunk, v_chunk = offload_engine.get_kv_for_pong_slots(
self.layer_id, num_blocks_in_chunk
)
# Wait for the Compute region to be ready and get KV
offload_engine.wait_compute()
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
self.layer_id, num_blocks_in_chunk
)
# Compute attention for this chunk
o_chunk, lse_chunk = flash_attn_with_lse(
@@ -286,8 +286,21 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Switch buffer for next iteration
current_buffer = "pong" if current_buffer == "ping" else "ping"
# Now attend to the Decode region (contains the new token's KV)
# This is the token being decoded - only 1 token at position pos_in_block
pos_in_block = context.decode_pos_in_block
decode_k, decode_v = offload_engine.get_kv_for_decode_slot(self.layer_id, pos_in_block)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, decode_k, decode_v,
softmax_scale=self.scale,
causal=False,
)
# Merge with accumulated
if o_acc is None:
o_acc = decode_o
else:
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
if o_acc is None:
raise RuntimeError("Chunked decode attention failed: no KV available")
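For readers unfamiliar with the layout the new code assumes: slot 0 of the per-layer GPU KV buffer is the Decode region (the token currently being decoded), followed by the Compute region and then the Prefetch region. The sketch below illustrates that contract and the event-based synchronization between a dedicated load stream and the compute stream. It is a hypothetical mock, not the actual OffloadEngine, whose load methods take only block ids and source KV from its own CPU cache; all names and shapes here are assumptions for illustration.

import torch

class ThreeRegionBufferSketch:
    # Hypothetical mock of the three-region contract used above:
    #   slot 0                          -> Decode region (new token's KV, never loaded over)
    #   slots 1 .. num_compute_blocks   -> Compute region
    #   remaining slots                 -> Prefetch region
    def __init__(self, num_layers, num_compute_blocks, num_prefetch_blocks,
                 block_size, num_kv_heads, head_dim, device="cuda"):
        self.num_compute_blocks = num_compute_blocks
        self.num_prefetch_blocks = num_prefetch_blocks
        slots = 1 + num_compute_blocks + num_prefetch_blocks
        shape = (num_layers, slots, block_size, num_kv_heads, head_dim)
        self.k_gpu = torch.zeros(shape, dtype=torch.float16, device=device)
        self.v_gpu = torch.zeros_like(self.k_gpu)
        self.load_stream = torch.cuda.Stream()   # H2D copies overlap with compute
        self._events = {}

    def _load(self, first_slot, cpu_block_ids, k_cpu, v_cpu):
        # Copy the requested CPU blocks for ALL layers at once, asynchronously on the
        # load stream (the CPU tensors must be pinned for non_blocking copies to overlap).
        with torch.cuda.stream(self.load_stream):
            for i, bid in enumerate(cpu_block_ids):
                self.k_gpu[:, first_slot + i].copy_(k_cpu[:, bid], non_blocking=True)
                self.v_gpu[:, first_slot + i].copy_(v_cpu[:, bid], non_blocking=True)
            event = torch.cuda.Event()
            event.record(self.load_stream)
        self._events[first_slot] = event

    def load_to_compute(self, cpu_block_ids, k_cpu, v_cpu):
        self._load(1, cpu_block_ids, k_cpu, v_cpu)

    def load_to_prefetch(self, cpu_block_ids, k_cpu, v_cpu):
        self._load(1 + self.num_compute_blocks, cpu_block_ids, k_cpu, v_cpu)

    def wait_compute(self):
        # Make the compute stream wait until the Compute-region copy has finished.
        torch.cuda.current_stream().wait_event(self._events[1])

    def wait_prefetch(self):
        torch.cuda.current_stream().wait_event(self._events[1 + self.num_compute_blocks])

    def get_kv_for_compute(self, layer_id, num_blocks):
        # Views over the Compute-region slots for one layer.
        return (self.k_gpu[layer_id, 1:1 + num_blocks],
                self.v_gpu[layer_id, 1:1 + num_blocks])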