[fix] Fixed kvcache offload bugs.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import torch
|
||||
from torch import nn
|
||||
import triton
|
||||
@@ -6,6 +7,8 @@ import triton.language as tl
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
from nanovllm.utils.context import get_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def store_kvcache_kernel(
|
||||
@@ -97,13 +100,16 @@ class Attention(nn.Module):
|
||||
context,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute attention with Ping-Pong dual buffer for chunked prefill.
|
||||
Compute attention with 三区域 GPU buffer for chunked prefill.
|
||||
|
||||
For chunked prefill:
|
||||
1. Load previous KV from CPU using Ping-Pong (if any previous chunks)
|
||||
1. Load previous KV from CPU using Compute/Prefetch区 (if any previous chunks)
|
||||
2. Compute attention against previous KV chunks (no causal mask)
|
||||
3. Compute attention against current chunk's KV (causal)
|
||||
4. Merge all results using online softmax
|
||||
|
||||
三区域设计保证:当前chunk的KV在Compute区,previous KV从CPU加载到Prefetch区,
|
||||
不会发生写入和加载区域重叠的问题。
|
||||
"""
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
@@ -116,7 +122,7 @@ class Attention(nn.Module):
|
||||
o_acc = None
|
||||
lse_acc = None
|
||||
|
||||
# Load previous KV from CPU using Ping-Pong
|
||||
# Load previous KV from CPU using Compute/Prefetch区
|
||||
# Note: context.offload_engine is actually HybridKVCacheManager
|
||||
kvcache_manager = context.offload_engine
|
||||
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
|
||||
@@ -127,41 +133,42 @@ class Attention(nn.Module):
|
||||
|
||||
if cpu_block_table:
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
ping_size = offload_engine.ping_size
|
||||
num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
|
||||
current_buffer = "ping"
|
||||
# 使用 Prefetch区 来加载 previous KV(不会与当前 Compute区 冲突)
|
||||
prefetch_size = offload_engine.num_prefetch_blocks
|
||||
num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
|
||||
use_compute = True # 交替使用 Compute区 和 Prefetch区
|
||||
|
||||
# Prefetch first chunk to Ping buffer
|
||||
first_chunk_end = min(ping_size, len(cpu_block_table))
|
||||
# 首先将 previous KV 加载到 Prefetch区
|
||||
# Only layer 0 triggers the load (loads ALL layers at once)
|
||||
first_chunk_end = min(prefetch_size, len(cpu_block_table))
|
||||
first_chunk_ids = cpu_block_table[:first_chunk_end]
|
||||
offload_engine.load_to_ping(first_chunk_ids)
|
||||
if self.layer_id == 0:
|
||||
offload_engine.load_to_prefetch(first_chunk_ids)
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
start = chunk_idx * ping_size
|
||||
end = min(start + ping_size, len(cpu_block_table))
|
||||
start = chunk_idx * prefetch_size
|
||||
end = min(start + prefetch_size, len(cpu_block_table))
|
||||
num_blocks_in_chunk = end - start
|
||||
|
||||
# Prefetch next chunk to OTHER buffer
|
||||
if chunk_idx + 1 < num_chunks:
|
||||
# Prefetch next chunk to other buffer (if exists)
|
||||
# Only layer 0 triggers the load
|
||||
if chunk_idx + 1 < num_chunks and self.layer_id == 0:
|
||||
next_start = end
|
||||
next_end = min(next_start + ping_size, len(cpu_block_table))
|
||||
next_end = min(next_start + prefetch_size, len(cpu_block_table))
|
||||
next_chunk_ids = cpu_block_table[next_start:next_end]
|
||||
if current_buffer == "ping":
|
||||
offload_engine.load_to_pong(next_chunk_ids)
|
||||
if use_compute:
|
||||
# 当前在 Prefetch区,下一个加载到 Compute区(如果有空间)
|
||||
# 注意:Compute区 此时已写入当前chunk的KV,不能覆盖
|
||||
# 所以这里我们使用简单的同步策略:等待当前完成后再加载
|
||||
pass # 简化版本:不进行双缓冲,只用 Prefetch区
|
||||
else:
|
||||
offload_engine.load_to_ping(next_chunk_ids)
|
||||
offload_engine.load_to_prefetch(next_chunk_ids)
|
||||
|
||||
# Wait for current buffer and get KV
|
||||
if current_buffer == "ping":
|
||||
offload_engine.wait_ping()
|
||||
prev_k, prev_v = offload_engine.get_kv_for_ping_slots(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
)
|
||||
else:
|
||||
offload_engine.wait_pong()
|
||||
prev_k, prev_v = offload_engine.get_kv_for_pong_slots(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
)
|
||||
# Wait for Prefetch区 and get KV
|
||||
offload_engine.wait_prefetch()
|
||||
prev_k, prev_v = offload_engine.get_kv_for_prefetch(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
)
|
||||
|
||||
# Compute attention against this chunk (no causal mask)
|
||||
prev_o, prev_lse = flash_attn_with_lse(
|
||||
@@ -178,8 +185,12 @@ class Attention(nn.Module):
|
||||
else:
|
||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
||||
|
||||
# Switch buffer
|
||||
current_buffer = "pong" if current_buffer == "ping" else "ping"
|
||||
# Load next chunk to Prefetch区 (if exists)
|
||||
if chunk_idx + 1 < num_chunks and self.layer_id == 0:
|
||||
next_start = end
|
||||
next_end = min(next_start + prefetch_size, len(cpu_block_table))
|
||||
next_chunk_ids = cpu_block_table[next_start:next_end]
|
||||
offload_engine.load_to_prefetch(next_chunk_ids)
|
||||
|
||||
# Compute attention against current chunk's KV (with causal mask)
|
||||
current_o, current_lse = flash_attn_with_lse(
|
||||
@@ -207,13 +218,16 @@ class Attention(nn.Module):
|
||||
context,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute decode attention with Ping-Pong dual buffer.
|
||||
Compute decode attention with 三区域 GPU buffer.
|
||||
|
||||
All KV is stored on CPU. Uses Ping-Pong buffers on GPU:
|
||||
1. Load first chunk to Ping buffer
|
||||
2. While computing on current buffer, prefetch next chunk to other buffer
|
||||
3. Alternate between Ping and Pong buffers
|
||||
4. Merge attention outputs using online softmax (LSE)
|
||||
All KV is stored on CPU. Uses Compute区 buffer on GPU:
|
||||
1. Load chunk to Compute区
|
||||
2. Compute attention
|
||||
3. Repeat for all chunks
|
||||
4. Finally, attend to Decode区 (slot 0) which contains the new token's KV
|
||||
5. Merge all attention outputs using online softmax (LSE)
|
||||
|
||||
关键:新token的KV在Decode区(slot 0),不会被Compute区的加载覆盖。
|
||||
"""
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
@@ -227,51 +241,37 @@ class Attention(nn.Module):
|
||||
|
||||
# Get all CPU blocks for this sequence
|
||||
cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
|
||||
if self.layer_id == 0:
|
||||
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
|
||||
if not cpu_block_table:
|
||||
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
|
||||
|
||||
# Get the actual offload_engine for Ping-Pong operations
|
||||
# Get the actual offload_engine for 三区域 operations
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
|
||||
# Calculate chunk info
|
||||
ping_size = offload_engine.ping_size
|
||||
num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
|
||||
# Calculate chunk info using Compute区
|
||||
compute_size = offload_engine.num_compute_blocks
|
||||
num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size
|
||||
|
||||
o_acc = None
|
||||
lse_acc = None
|
||||
current_buffer = "ping"
|
||||
|
||||
# Prefetch first chunk to Ping buffer (loads all layers at once)
|
||||
first_chunk_end = min(ping_size, len(cpu_block_table))
|
||||
first_chunk_ids = cpu_block_table[:first_chunk_end]
|
||||
offload_engine.load_to_ping(first_chunk_ids)
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
start = chunk_idx * ping_size
|
||||
end = min(start + ping_size, len(cpu_block_table))
|
||||
start = chunk_idx * compute_size
|
||||
end = min(start + compute_size, len(cpu_block_table))
|
||||
num_blocks_in_chunk = end - start
|
||||
chunk_ids = cpu_block_table[start:end]
|
||||
|
||||
# Prefetch next chunk to OTHER buffer (overlapped with current computation)
|
||||
if chunk_idx + 1 < num_chunks:
|
||||
next_start = end
|
||||
next_end = min(next_start + ping_size, len(cpu_block_table))
|
||||
next_chunk_ids = cpu_block_table[next_start:next_end]
|
||||
if current_buffer == "ping":
|
||||
offload_engine.load_to_pong(next_chunk_ids)
|
||||
else:
|
||||
offload_engine.load_to_ping(next_chunk_ids)
|
||||
# Load this chunk to Compute区
|
||||
# Only layer 0 triggers the load (loads ALL layers at once)
|
||||
if self.layer_id == 0:
|
||||
offload_engine.load_to_compute(chunk_ids)
|
||||
|
||||
# Wait for current buffer to be ready and get KV
|
||||
if current_buffer == "ping":
|
||||
offload_engine.wait_ping()
|
||||
k_chunk, v_chunk = offload_engine.get_kv_for_ping_slots(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
)
|
||||
else:
|
||||
offload_engine.wait_pong()
|
||||
k_chunk, v_chunk = offload_engine.get_kv_for_pong_slots(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
)
|
||||
# Wait for Compute区 to be ready and get KV
|
||||
offload_engine.wait_compute()
|
||||
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
)
|
||||
|
||||
# Compute attention for this chunk
|
||||
o_chunk, lse_chunk = flash_attn_with_lse(
|
||||
@@ -286,8 +286,21 @@ class Attention(nn.Module):
|
||||
else:
|
||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
|
||||
|
||||
# Switch buffer for next iteration
|
||||
current_buffer = "pong" if current_buffer == "ping" else "ping"
|
||||
# Now attend to Decode区 (contains the new token's KV)
|
||||
# This is the token being decoded - only 1 token at position pos_in_block
|
||||
pos_in_block = context.decode_pos_in_block
|
||||
decode_k, decode_v = offload_engine.get_kv_for_decode_slot(self.layer_id, pos_in_block)
|
||||
decode_o, decode_lse = flash_attn_with_lse(
|
||||
q_batched, decode_k, decode_v,
|
||||
softmax_scale=self.scale,
|
||||
causal=False,
|
||||
)
|
||||
|
||||
# Merge with accumulated
|
||||
if o_acc is None:
|
||||
o_acc = decode_o
|
||||
else:
|
||||
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
|
||||
|
||||
if o_acc is None:
|
||||
raise RuntimeError("Chunked decode attention failed: no KV available")
|
||||
|
||||
Reference in New Issue
Block a user