[fix] Fixed kvcache offload bugs.

This commit is contained in:
Zijie Tian
2025-12-10 22:34:00 +08:00
parent 190df5f70d
commit e85c2b4776
7 changed files with 409 additions and 156 deletions

View File

@@ -22,6 +22,7 @@ class Config:
offload_policy: str = "lru" # "lru", "fifo", or full class path
num_transfer_streams: int = 4 # Number of CUDA streams for async transfers
num_gpu_blocks: int = -1 # User-specified GPU blocks count, -1 = auto (use max available)
num_prefetch_blocks: int = 2 # Number of blocks in the prefetch region (three-region GPU buffer design)
# Computed fields for offload (set in __post_init__ or by ModelRunner)
num_gpu_kvcache_blocks: int = -1
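For context, a minimal sketch (illustrative Python, mirroring the OffloadEngine assertions added later in this commit) of the GPU block budget the new num_prefetch_blocks field implies:

# The three-region buffer reserves 1 decode slot + at least 1 compute slot
# + num_prefetch_blocks prefetch slots, so with the default of 2 prefetch blocks
# the offload path needs at least 4 GPU KV-cache blocks.
num_prefetch_blocks = 2
min_gpu_blocks = 1 + 1 + num_prefetch_blocks
assert min_gpu_blocks == 4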

View File

@@ -148,19 +148,39 @@ class ModelRunner:
dtype=hf_config.torch_dtype,
)
# Log KV cache allocation info
# Log KV cache allocation info with detailed per-token breakdown
gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
total_memory_mb = gpu_memory_mb + cpu_memory_mb
# Calculate per-token KV cache usage
# KV per token = 2 (K+V) * num_layers * kv_heads * head_dim * dtype_size
dtype_size = 2 if hf_config.torch_dtype in [torch.float16, torch.bfloat16] else 4
per_token_kv_bytes = 2 * hf_config.num_hidden_layers * num_kv_heads * head_dim * dtype_size
per_token_kv_kb = per_token_kv_bytes / 1024
logger.info(
f"KV Cache per-token: {per_token_kv_kb:.2f}KB "
f"(2 * {hf_config.num_hidden_layers}layers * {num_kv_heads}kv_heads * {head_dim}head_dim * {dtype_size}bytes)"
)
logger.info(
f"KV Cache per-block: {block_bytes / (1024**2):.2f}MB "
f"({per_token_kv_kb:.2f}KB * {self.block_size}tokens)"
)
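As a rough sanity check of the logged formula, a small worked example with hypothetical model shapes (not taken from this diff): 32 layers, 8 KV heads, head_dim 128, fp16, block_size 16.

# Hypothetical shapes, for illustration only; real values come from hf_config.
num_hidden_layers, num_kv_heads, head_dim, dtype_size = 32, 8, 128, 2  # fp16
per_token_kv_bytes = 2 * num_hidden_layers * num_kv_heads * head_dim * dtype_size
print(per_token_kv_bytes / 1024)                # 128.0 KB per token
block_size = 16                                 # assumed block size
print(per_token_kv_bytes * block_size / 2**20)  # 2.0 MB per block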
if config.enable_cpu_offload:
ping_size = config.num_gpu_kvcache_blocks // 2
tokens_per_ping = ping_size * self.block_size
logger.info(
f"KV Cache allocated (Ping-Pong mode): "
f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), "
f"Total={total_memory_mb:.1f}MB, "
f"block_size={self.block_size}, "
f"ping_size={config.num_gpu_kvcache_blocks // 2}"
f"Total={total_memory_mb:.1f}MB"
)
logger.info(
f"Ping-Pong config: ping_size={ping_size} blocks, "
f"tokens_per_chunk={tokens_per_ping}, "
f"block_size={self.block_size}"
)
else:
logger.info(
@@ -392,12 +412,12 @@ class ModelRunner:
def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool:
"""
Check if Ping-Pong mode should be used.
Check if three-region mode should be used.
Use Ping-Pong when:
Use three-region mode when:
- CPU offload is enabled
- There are blocks on CPU (either allocated there or offloaded)
- Sequence exceeds GPU capacity
- Sequence exceeds the GPU compute region's capacity
"""
if not hasattr(self.kvcache_manager, 'offload_engine'):
return False
@@ -409,12 +429,12 @@ class ModelRunner:
# Check if any blocks are on CPU
cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq)
if cpu_blocks:
# Has CPU blocks - use Ping-Pong
# Has CPU blocks - use three-region mode
return True
# Check if sequence needs more blocks than GPU can hold
ping_size = self.kvcache_manager.offload_engine.ping_size
if seq.num_blocks > ping_size:
# Check if sequence needs more blocks than the GPU compute region can hold
compute_size = self.kvcache_manager.offload_engine.num_compute_blocks
if seq.num_blocks > compute_size:
# Needs chunked processing
return True
@@ -610,29 +630,28 @@ class ModelRunner:
def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill with Ping-Pong dual buffer (CPU is primary storage).
Run prefill with the three-region GPU buffer (CPU is primary storage).
Flow:
1. All blocks are allocated to CPU (primary storage)
2. Process tokens in chunks using Ping/Pong GPU buffers
3. After each chunk, offload from GPU to CPU
4. Alternate between Ping and Pong buffers
2. Process tokens in chunks using the compute-region GPU buffer
3. After each chunk, offload from the compute region to CPU
4. The prefetch region is used to load previous KV (if any)
"""
import sys
assert len(seqs) == 1, "Ping-Pong prefill only supports single sequence"
assert len(seqs) == 1, "Three-region prefill only supports a single sequence"
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine
ping_size = offload_engine.ping_size
tokens_per_chunk = ping_size * self.block_size
compute_size = offload_engine.num_compute_blocks
tokens_per_chunk = compute_size * self.block_size
total_tokens = len(seq)
print(f"[Ping-Pong Prefill] Starting: {total_tokens} tokens, "
f"ping_size={ping_size} blocks, chunk={tokens_per_chunk} tokens",
print(f"[三区域 Prefill] Starting: {total_tokens} tokens, "
f"compute_size={compute_size} blocks, chunk={tokens_per_chunk} tokens",
file=sys.stderr)
current_buffer = "ping"
chunk_num = 0
logits = None
processed_tokens = 0
@@ -651,15 +670,13 @@ class ModelRunner:
end_block_idx = (chunk_end + self.block_size - 1) // self.block_size
num_blocks = end_block_idx - start_block_idx
print(f"[Ping-Pong Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
f"blocks {start_block_idx}-{end_block_idx-1}, buffer={current_buffer}",
print(f"[三区域 Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
f"blocks {start_block_idx}-{end_block_idx-1}, "
f"compute_slots={offload_engine.compute_slots[:num_blocks]}",
file=sys.stderr)
# Get GPU slots for this chunk (Ping or Pong buffer)
if current_buffer == "ping":
gpu_slots = offload_engine.ping_slots[:num_blocks]
else:
gpu_slots = offload_engine.pong_slots[:num_blocks]
# Get GPU slots for this chunk (use the compute region)
gpu_slots = offload_engine.compute_slots[:num_blocks]
# Prepare inputs
input_ids, positions = self._prepare_pingpong_chunk(
@@ -678,24 +695,19 @@ class ModelRunner:
logical_id = seq.block_table[i]
self.kvcache_manager.prefilled_blocks.add(logical_id)
# Offload this chunk from GPU to CPU (async)
# Offload this chunk from the compute region to CPU (async)
chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx]
offload_engine.offload_buffer_to_cpu(current_buffer, chunk_cpu_blocks)
offload_engine.offload_compute_to_cpu(chunk_cpu_blocks)
# Switch buffer for next chunk
if current_buffer == "ping":
offload_engine.wait_ping_offload_done()
current_buffer = "pong"
else:
offload_engine.wait_pong_offload_done()
current_buffer = "ping"
# Wait for offload to complete before next chunk
offload_engine.wait_all_offload_done()
processed_tokens = chunk_end
# Wait for all offloads to complete
offload_engine.wait_all_offload_done()
print(f"[Ping-Pong Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
print(f"[三区域 Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
# Sample from last logits
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
@@ -764,25 +776,26 @@ class ModelRunner:
def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with Ping-Pong dual buffer.
Run decode with the three-region GPU buffer.
All KV is on CPU. Uses Ping-Pong to load KV chunks and compute attention.
New token's KV is written to GPU then offloaded to CPU.
All KV is on CPU. Uses the decode region to write new KV, and the compute/prefetch regions to load KV chunks.
The new token's KV is written to the decode region (slot 0) and then offloaded to CPU.
Key point: the decode region is never overwritten by compute/prefetch loads; it is dedicated to writing new KV.
"""
assert len(seqs) == 1, "Ping-Pong decode only supports single sequence"
assert len(seqs) == 1, "Three-region decode only supports a single sequence"
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine
# Prepare inputs
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
# Get write slot for new KV (will use last slot of the buffer used for final chunk)
write_slot = self.kvcache_manager.get_write_slot_for_pingpong(seq)
# Calculate position in block for slot mapping
last_block_idx = seq.num_blocks - 1
# Use the decode region (slot 0) to write the new KV
decode_slot = offload_engine.decode_slot # = 0
pos_in_block = (len(seq) - 1) % self.block_size
slot = write_slot * self.block_size + pos_in_block
slot = decode_slot * self.block_size + pos_in_block
slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
@@ -794,17 +807,18 @@ class ModelRunner:
is_chunked_prefill=True, # Use chunked attention path
offload_engine=self.kvcache_manager,
chunked_seq=seq,
decode_pos_in_block=pos_in_block,
)
# Run model forward pass
logits = self.run_model(input_ids, positions, is_prefill=False)
reset_context()
# Offload new KV from write_slot to CPU
# Offload new KV from the decode region to CPU
last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
if last_cpu_block >= 0:
self.kvcache_manager.offload_engine.offload_slot_to_cpu(write_slot, last_cpu_block)
self.kvcache_manager.offload_engine.wait_all_offload_done()
offload_engine.offload_decode_slot(last_cpu_block)
offload_engine.wait_all_offload_done()
# Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None

View File

@@ -58,12 +58,14 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
from nanovllm.kvcache.policies import get_policy
policy = get_policy(getattr(config, 'offload_policy', 'lru'))
num_prefetch_blocks = getattr(config, 'num_prefetch_blocks', 2)
return HybridKVCacheManager(
num_gpu_slots=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=config.kvcache_block_size,
policy=policy,
num_prefetch_blocks=num_prefetch_blocks,
)

View File

@@ -9,6 +9,7 @@ Key design for CUDA Graph compatibility:
5. Graph replay only needs index updates (tiny overhead)
"""
import logging
from collections import deque
from dataclasses import dataclass, field
from enum import Enum, auto
@@ -16,6 +17,8 @@ from typing import List, Tuple, Dict, Set, Optional
import torch
from torch import Tensor
logger = logging.getLogger(__name__)
from nanovllm.engine.sequence import Sequence
from nanovllm.kvcache.base_manager import KVCacheManager
from nanovllm.kvcache.offload_engine import OffloadEngine
@@ -82,6 +85,7 @@ class HybridKVCacheManager(KVCacheManager):
block_size: int,
policy: Optional[EvictionPolicy] = None,
cpu_primary: bool = True,
num_prefetch_blocks: int = 2,
):
"""
Initialize hybrid manager.
@@ -91,14 +95,16 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
block_size: Tokens per block
policy: Eviction policy (default: LRU)
cpu_primary: If True, use CPU as primary storage with Ping-Pong GPU buffer.
cpu_primary: If True, use CPU as primary storage with the three-region GPU buffer.
If False, use GPU as primary with CPU as overflow (legacy mode).
num_prefetch_blocks: Number of prefetch blocks for the three-region GPU buffer design
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
self.num_cpu_blocks = num_cpu_blocks
self.total_blocks = num_gpu_slots + num_cpu_blocks
self.cpu_primary = cpu_primary # Ping-Pong mode flag
self.cpu_primary = cpu_primary # three-region mode flag
self.num_prefetch_blocks = num_prefetch_blocks # three-region design parameter
# Eviction policy
self.policy = policy or LRUPolicy()
@@ -156,6 +162,7 @@ class HybridKVCacheManager(KVCacheManager):
num_kv_heads=num_kv_heads,
head_dim=head_dim,
dtype=dtype,
num_prefetch_blocks=self.num_prefetch_blocks,
)
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
@@ -948,6 +955,10 @@ class HybridKVCacheManager(KVCacheManager):
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
cpu_blocks.append(block.cpu_block_id)
logger.debug(
f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
f"returned cpu_blocks={cpu_blocks}"
)
return cpu_blocks
def load_prev_kv_for_layer(
@@ -1139,28 +1150,18 @@ class HybridKVCacheManager(KVCacheManager):
def get_write_slot_for_pingpong(self, seq: Sequence) -> int:
"""
Get the GPU slot where the new KV is written during Ping-Pong decode.
Get the GPU slot where the new KV is written during three-region decode.
Strategy: use the number of chunks the sequence needs to decide whether the last buffer used was Ping or Pong,
then use the last slot of that buffer.
In the three-region design, the decode region (slot 0) is always used to write the new KV.
This avoids conflicts with loads into the compute/prefetch regions.
Args:
seq: the sequence
Returns:
GPU slot ID
GPU slot ID (always decode_slot = 0)
"""
cpu_blocks, _ = self.get_all_cpu_blocks(seq)
ping_size = self.offload_engine.ping_size
num_chunks = (len(cpu_blocks) + ping_size - 1) // ping_size if cpu_blocks else 0
# Which buffer the last chunk used
if num_chunks % 2 == 1 or num_chunks == 0:
# Odd number of chunks (or zero): the last chunk used ping
return self.offload_engine.ping_slots[-1]
else:
# Even number of chunks: the last chunk used pong
return self.offload_engine.pong_slots[-1]
return self.offload_engine.decode_slot
def __repr__(self) -> str:
return (

View File

@@ -53,6 +53,7 @@ class OffloadEngine:
head_dim: int,
dtype: torch.dtype = torch.float16,
num_streams: int = 4,
num_prefetch_blocks: int = 2,
):
self.num_layers = num_layers
self.num_gpu_blocks = num_gpu_blocks
@@ -64,31 +65,59 @@ class OffloadEngine:
self.kv_dim = num_kv_heads * head_dim
self.block_numel = block_size * self.kv_dim
# ========== Ping-Pong double-buffer configuration ==========
assert num_gpu_blocks >= 2, "Ping-Pong requires at least 2 GPU blocks"
self.ping_size = num_gpu_blocks // 2
self.pong_size = num_gpu_blocks - self.ping_size
self.ping_slots = list(range(self.ping_size)) # [0, 1, 2, ...]
self.pong_slots = list(range(self.ping_size, num_gpu_blocks)) # [ping_size, ping_size+1, ...]
# ========== Three-region GPU buffer configuration ==========
# Constraint checks
assert num_gpu_blocks >= 3, \
f"Need at least 3 GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
assert num_prefetch_blocks >= 1, \
f"Need at least 1 prefetch block, got {num_prefetch_blocks}"
assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \
f"Not enough GPU blocks: need 1 (decode) + 1 (compute) + {num_prefetch_blocks} (prefetch), got {num_gpu_blocks}"
# Three-region layout
# Decode region: [0] - a single fixed block for writing new KV
self.decode_slot = 0
# Compute region: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
compute_start = 1
compute_end = num_gpu_blocks - num_prefetch_blocks
self.compute_slots = list(range(compute_start, compute_end))
self.num_compute_blocks = len(self.compute_slots)
# Prefetch region: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
prefetch_start = compute_end
self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks))
self.num_prefetch_blocks = num_prefetch_blocks
self.num_gpu_slots = num_gpu_blocks # alias
# Keep the old ping/pong attributes for compatibility (to be removed later)
self.ping_size = self.num_compute_blocks
self.pong_size = self.num_prefetch_blocks
self.ping_slots = self.compute_slots.copy()
self.pong_slots = self.prefetch_slots.copy()
logger.info(f"三区域 GPU Buffer: decode_slot={self.decode_slot}, "
f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")
# ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
self.k_cache_gpu = torch.empty(
# Initialize with zeros to avoid issues with uninitialized memory
self.k_cache_gpu = torch.zeros(
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
self.v_cache_gpu = torch.empty(
self.v_cache_gpu = torch.zeros(
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
)
# ========== Fixed-address CPU KV cache (pinned memory) ==========
self.k_cache_cpu = torch.empty(
self.k_cache_cpu = torch.zeros(
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cpu", pin_memory=True
)
self.v_cache_cpu = torch.empty(
self.v_cache_cpu = torch.zeros(
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cpu", pin_memory=True
)
@@ -111,14 +140,18 @@ class OffloadEngine:
self.compute_stream = torch.cuda.current_stream()
self._stream_idx = 0
# ========== Dedicated Ping-Pong stream and events ==========
self.pingpong_stream = torch.cuda.Stream() # dedicated to Ping-Pong transfers
# ========== Dedicated three-region stream and events ==========
self.transfer_stream_main = torch.cuda.Stream() # main transfer stream
# Sync events - load complete
self.ping_ready = torch.cuda.Event()
self.pong_ready = torch.cuda.Event()
# Sync events - three-region loads complete
self.compute_ready = torch.cuda.Event()
self.prefetch_ready = torch.cuda.Event()
self.decode_offload_done = torch.cuda.Event()
# Sync events - offload complete
# Keep the old ping/pong events for compatibility (to be removed later)
self.pingpong_stream = self.transfer_stream_main
self.ping_ready = self.compute_ready
self.pong_ready = self.prefetch_ready
self.ping_offload_done = torch.cuda.Event()
self.pong_offload_done = torch.cuda.Event()
@@ -535,7 +568,7 @@ class OffloadEngine:
f" kv_heads={self.num_kv_heads},\n"
f" head_dim={self.head_dim},\n"
f" dtype={self.dtype},\n"
f" ping_size={self.ping_size}, pong_size={self.pong_size},\n"
f" 三区域: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
f")"
@@ -742,4 +775,189 @@ class OffloadEngine:
v = self.v_cache_gpu[layer_id, gpu_slots]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
return k, v
# ========== Three-region GPU buffer methods ==========
def load_to_compute(self, cpu_block_ids: List[int]) -> None:
"""
Asynchronously load CPU blocks into the compute region.
Args:
cpu_block_ids: list of CPU block IDs to load
"""
if not cpu_block_ids:
self.compute_ready.record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
logger.debug(f"Compute load: CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.compute_slots[i]
# Copy all layers at once
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.compute_ready.record(self.transfer_stream_main)
def load_to_prefetch(self, cpu_block_ids: List[int]) -> None:
"""
Asynchronously load CPU blocks into the prefetch region.
Args:
cpu_block_ids: list of CPU block IDs to load
"""
if not cpu_block_ids:
self.prefetch_ready.record(self.transfer_stream_main)
return
num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
logger.debug(f"Prefetch load: CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")
with torch.cuda.stream(self.transfer_stream_main):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.prefetch_slots[i]
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.prefetch_ready.record(self.transfer_stream_main)
def wait_compute(self) -> None:
"""等待Compute区加载完成。"""
self.compute_stream.wait_event(self.compute_ready)
def wait_prefetch(self) -> None:
"""等待Prefetch区加载完成。"""
self.compute_stream.wait_event(self.prefetch_ready)
def swap_compute_prefetch(self) -> None:
"""交换Compute区和Prefetch区的角色。"""
self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
# Also update the old ping/pong slots for compatibility
self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots
def offload_decode_slot(self, cpu_block_id: int) -> None:
"""
Offload the decode region's KV to CPU.
Args:
cpu_block_id: target CPU block ID
"""
logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream)
self.k_cache_cpu[:, cpu_block_id].copy_(
self.k_cache_gpu[:, self.decode_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_block_id].copy_(
self.v_cache_gpu[:, self.decode_slot], non_blocking=True
)
self.decode_offload_done.record(self.transfer_stream_main)
def wait_decode_offload(self) -> None:
"""等待Decode区offload完成。"""
self.compute_stream.wait_event(self.decode_offload_done)
def get_kv_for_compute(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
Get the KV of the requested number of blocks in the compute region.
Args:
layer_id: layer ID
num_blocks: number of blocks needed
Returns:
(k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.compute_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots] # [num_blocks, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slots]
# Reshape: [num_blocks, block_size, heads, dim] -> [1, num_blocks*block_size, heads, dim]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_prefetch(
self,
layer_id: int,
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
Get the KV of the requested number of blocks in the prefetch region.
Args:
layer_id: layer ID
num_blocks: number of blocks needed
Returns:
(k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.prefetch_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots]
v = self.v_cache_gpu[layer_id, slots]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_decode_slot(
self,
layer_id: int,
pos_in_block: int,
) -> Tuple[Tensor, Tensor]:
"""
Get the KV at a given position in the decode region (the new token during decode).
Args:
layer_id: layer ID
pos_in_block: position of the token within the block (0 to block_size-1)
Returns:
(k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] # [1, heads, dim]
v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
k = k.unsqueeze(0) # [1, 1, heads, dim]
v = v.unsqueeze(0)
return k, v
def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
"""
Offload the compute region's KV to CPU.
Args:
cpu_block_ids: list of target CPU block IDs
"""
if not cpu_block_ids:
return
num_to_offload = min(len(cpu_block_ids), len(self.compute_slots))
logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for computation to finish
self.transfer_stream_main.wait_stream(self.compute_stream)
for i in range(num_to_offload):
gpu_slot = self.compute_slots[i]
cpu_id = cpu_block_ids[i]
self.k_cache_cpu[:, cpu_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
)
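For orientation, a minimal sketch of how these primitives are driven per chunk during prefill (simplified; run_forward is a placeholder for the model forward pass that writes the chunk's KV into the compute slots, and block-table bookkeeping is elided):

# Sketch only - not the full run_pingpong_prefill logic.
def prefill_chunks(offload_engine, cpu_block_ids, run_forward):
    compute_size = offload_engine.num_compute_blocks
    for start in range(0, len(cpu_block_ids), compute_size):
        chunk = cpu_block_ids[start:start + compute_size]     # destination CPU blocks
        gpu_slots = offload_engine.compute_slots[:len(chunk)]
        run_forward(gpu_slots)                        # writes this chunk's KV into the compute region
        offload_engine.offload_compute_to_cpu(chunk)  # async GPU -> CPU copy on transfer_stream_main
        offload_engine.wait_all_offload_done()        # block before reusing the compute slots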

View File

@@ -1,3 +1,4 @@
import logging
import torch
from torch import nn
import triton
@@ -6,6 +7,8 @@ import triton.language as tl
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context
logger = logging.getLogger(__name__)
@triton.jit
def store_kvcache_kernel(
@@ -97,13 +100,16 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute attention with Ping-Pong dual buffer for chunked prefill.
Compute attention with the three-region GPU buffer for chunked prefill.
For chunked prefill:
1. Load previous KV from CPU using Ping-Pong (if any previous chunks)
1. Load previous KV from CPU using the compute/prefetch regions (if any previous chunks)
2. Compute attention against previous KV chunks (no causal mask)
3. Compute attention against current chunk's KV (causal)
4. Merge all results using online softmax
The three-region design keeps the current chunk's KV in the compute region while previous KV is loaded from CPU into the prefetch region,
so write and load regions never overlap.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@@ -116,7 +122,7 @@ class Attention(nn.Module):
o_acc = None
lse_acc = None
# Load previous KV from CPU using Ping-Pong
# Load previous KV from CPU using the compute/prefetch regions
# Note: context.offload_engine is actually HybridKVCacheManager
kvcache_manager = context.offload_engine
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
@@ -127,41 +133,42 @@ class Attention(nn.Module):
if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
ping_size = offload_engine.ping_size
num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
current_buffer = "ping"
# Use the prefetch region to load previous KV (does not conflict with the current compute region)
prefetch_size = offload_engine.num_prefetch_blocks
num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
use_compute = True # alternate between the compute and prefetch regions
# Prefetch first chunk to Ping buffer
first_chunk_end = min(ping_size, len(cpu_block_table))
# First, load the previous KV into the prefetch region
# Only layer 0 triggers the load (loads ALL layers at once)
first_chunk_end = min(prefetch_size, len(cpu_block_table))
first_chunk_ids = cpu_block_table[:first_chunk_end]
offload_engine.load_to_ping(first_chunk_ids)
if self.layer_id == 0:
offload_engine.load_to_prefetch(first_chunk_ids)
for chunk_idx in range(num_chunks):
start = chunk_idx * ping_size
end = min(start + ping_size, len(cpu_block_table))
start = chunk_idx * prefetch_size
end = min(start + prefetch_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
# Prefetch next chunk to OTHER buffer
if chunk_idx + 1 < num_chunks:
# Prefetch next chunk to other buffer (if exists)
# Only layer 0 triggers the load
if chunk_idx + 1 < num_chunks and self.layer_id == 0:
next_start = end
next_end = min(next_start + ping_size, len(cpu_block_table))
next_end = min(next_start + prefetch_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if current_buffer == "ping":
offload_engine.load_to_pong(next_chunk_ids)
if use_compute:
# Currently using the prefetch region; the next chunk would go to the compute region (if it had room).
# Note: the compute region already holds the current chunk's KV and must not be overwritten,
# so we use a simple synchronous strategy: wait for the current chunk to finish before loading.
pass # simplified version: no double buffering, only the prefetch region is used
else:
offload_engine.load_to_ping(next_chunk_ids)
offload_engine.load_to_prefetch(next_chunk_ids)
# Wait for current buffer and get KV
if current_buffer == "ping":
offload_engine.wait_ping()
prev_k, prev_v = offload_engine.get_kv_for_ping_slots(
self.layer_id, num_blocks_in_chunk
)
else:
offload_engine.wait_pong()
prev_k, prev_v = offload_engine.get_kv_for_pong_slots(
self.layer_id, num_blocks_in_chunk
)
# Wait for the prefetch region and get KV
offload_engine.wait_prefetch()
prev_k, prev_v = offload_engine.get_kv_for_prefetch(
self.layer_id, num_blocks_in_chunk
)
# Compute attention against this chunk (no causal mask)
prev_o, prev_lse = flash_attn_with_lse(
@@ -178,8 +185,12 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Switch buffer
current_buffer = "pong" if current_buffer == "ping" else "ping"
# Load the next chunk into the prefetch region (if any)
if chunk_idx + 1 < num_chunks and self.layer_id == 0:
next_start = end
next_end = min(next_start + prefetch_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
offload_engine.load_to_prefetch(next_chunk_ids)
# Compute attention against current chunk's KV (with causal mask)
current_o, current_lse = flash_attn_with_lse(
@@ -207,13 +218,16 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute decode attention with Ping-Pong dual buffer.
Compute decode attention with the three-region GPU buffer.
All KV is stored on CPU. Uses Ping-Pong buffers on GPU:
1. Load first chunk to Ping buffer
2. While computing on current buffer, prefetch next chunk to other buffer
3. Alternate between Ping and Pong buffers
4. Merge attention outputs using online softmax (LSE)
All KV is stored on CPU. Uses the compute-region buffer on GPU:
1. Load a chunk into the compute region
2. Compute attention
3. Repeat for all chunks
4. Finally, attend to the decode region (slot 0), which contains the new token's KV
5. Merge all attention outputs using online softmax (LSE)
Key point: the new token's KV lives in the decode region (slot 0) and is never overwritten by compute-region loads.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@@ -227,51 +241,37 @@ class Attention(nn.Module):
# Get all CPU blocks for this sequence
cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
if self.layer_id == 0:
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
# Get the actual offload_engine for Ping-Pong operations
# Get the actual offload_engine for three-region operations
offload_engine = kvcache_manager.offload_engine
# Calculate chunk info
ping_size = offload_engine.ping_size
num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
# Calculate chunk info using the compute region
compute_size = offload_engine.num_compute_blocks
num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size
o_acc = None
lse_acc = None
current_buffer = "ping"
# Prefetch first chunk to Ping buffer (loads all layers at once)
first_chunk_end = min(ping_size, len(cpu_block_table))
first_chunk_ids = cpu_block_table[:first_chunk_end]
offload_engine.load_to_ping(first_chunk_ids)
for chunk_idx in range(num_chunks):
start = chunk_idx * ping_size
end = min(start + ping_size, len(cpu_block_table))
start = chunk_idx * compute_size
end = min(start + compute_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
chunk_ids = cpu_block_table[start:end]
# Prefetch next chunk to OTHER buffer (overlapped with current computation)
if chunk_idx + 1 < num_chunks:
next_start = end
next_end = min(next_start + ping_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if current_buffer == "ping":
offload_engine.load_to_pong(next_chunk_ids)
else:
offload_engine.load_to_ping(next_chunk_ids)
# Load this chunk into the compute region
# Only layer 0 triggers the load (loads ALL layers at once)
if self.layer_id == 0:
offload_engine.load_to_compute(chunk_ids)
# Wait for current buffer to be ready and get KV
if current_buffer == "ping":
offload_engine.wait_ping()
k_chunk, v_chunk = offload_engine.get_kv_for_ping_slots(
self.layer_id, num_blocks_in_chunk
)
else:
offload_engine.wait_pong()
k_chunk, v_chunk = offload_engine.get_kv_for_pong_slots(
self.layer_id, num_blocks_in_chunk
)
# Wait for the compute region to be ready and get KV
offload_engine.wait_compute()
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
self.layer_id, num_blocks_in_chunk
)
# Compute attention for this chunk
o_chunk, lse_chunk = flash_attn_with_lse(
@@ -286,8 +286,21 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Switch buffer for next iteration
current_buffer = "pong" if current_buffer == "ping" else "ping"
# Now attend to the decode region (contains the new token's KV)
# This is the token being decoded - only 1 token at position pos_in_block
pos_in_block = context.decode_pos_in_block
decode_k, decode_v = offload_engine.get_kv_for_decode_slot(self.layer_id, pos_in_block)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, decode_k, decode_v,
softmax_scale=self.scale,
causal=False,
)
# Merge with accumulated
if o_acc is None:
o_acc = decode_o
else:
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
if o_acc is None:
raise RuntimeError("Chunked decode attention failed: no KV available")

View File

@@ -27,6 +27,8 @@ class Context:
prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
# Current sequence being processed (for chunked prefill to load KV)
chunked_seq: Any = None
# Position within block for decode (used for reading from the decode region)
decode_pos_in_block: int = 0
_CONTEXT = Context()
@@ -50,6 +52,7 @@ def set_context(
chunk_offset=0,
offload_engine=None,
chunked_seq=None,
decode_pos_in_block=0,
):
global _CONTEXT
_CONTEXT = Context(
@@ -66,6 +69,7 @@ def set_context(
chunk_offset=chunk_offset,
offload_engine=offload_engine,
chunked_seq=chunked_seq,
decode_pos_in_block=decode_pos_in_block,
)