[fix] Fixed kvcache offload bugs.
@@ -58,12 +58,14 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
    from nanovllm.kvcache.policies import get_policy

    policy = get_policy(getattr(config, 'offload_policy', 'lru'))
    num_prefetch_blocks = getattr(config, 'num_prefetch_blocks', 2)

    return HybridKVCacheManager(
        num_gpu_slots=num_gpu_blocks,
        num_cpu_blocks=num_cpu_blocks,
        block_size=config.kvcache_block_size,
        policy=policy,
        num_prefetch_blocks=num_prefetch_blocks,
    )
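For reference, a minimal sketch of the config fields this factory reads. The field names mirror the getattr calls above; the stub class and its example values are illustrative only, not the real nanovllm Config (which also supplies the GPU/CPU block counts used outside this hunk).

    # Illustrative stub only; example values, not documented defaults.
    from dataclasses import dataclass

    @dataclass
    class ConfigStub:
        kvcache_block_size: int = 256     # tokens per KV block
        offload_policy: str = "lru"       # read via getattr(config, 'offload_policy', 'lru')
        num_prefetch_blocks: int = 2      # read via getattr(config, 'num_prefetch_blocks', 2)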
@@ -9,6 +9,7 @@ Key design for CUDA Graph compatibility:
5. Graph replay only needs index updates (tiny overhead)
"""

import logging
from collections import deque
from dataclasses import dataclass, field
from enum import Enum, auto
@@ -16,6 +17,8 @@ from typing import List, Tuple, Dict, Set, Optional
import torch
from torch import Tensor

logger = logging.getLogger(__name__)

from nanovllm.engine.sequence import Sequence
from nanovllm.kvcache.base_manager import KVCacheManager
from nanovllm.kvcache.offload_engine import OffloadEngine
@@ -82,6 +85,7 @@ class HybridKVCacheManager(KVCacheManager):
        block_size: int,
        policy: Optional[EvictionPolicy] = None,
        cpu_primary: bool = True,
        num_prefetch_blocks: int = 2,
    ):
        """
        Initialize hybrid manager.
@@ -91,14 +95,16 @@ class HybridKVCacheManager(KVCacheManager):
            num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
            block_size: Tokens per block
            policy: Eviction policy (default: LRU)
-            cpu_primary: If True, use CPU as primary storage with Ping-Pong GPU buffer.
+            cpu_primary: If True, use CPU as primary storage with a three-region GPU buffer.
                If False, use GPU as primary with CPU as overflow (legacy mode).
            num_prefetch_blocks: Number of prefetch blocks for the three-region GPU buffer design
        """
        self._block_size = block_size
        self.num_gpu_slots = num_gpu_slots
        self.num_cpu_blocks = num_cpu_blocks
        self.total_blocks = num_gpu_slots + num_cpu_blocks
-        self.cpu_primary = cpu_primary  # Ping-Pong mode flag
+        self.cpu_primary = cpu_primary  # three-region mode flag
        self.num_prefetch_blocks = num_prefetch_blocks  # three-region design parameter

        # Eviction policy
        self.policy = policy or LRUPolicy()
@@ -156,6 +162,7 @@ class HybridKVCacheManager(KVCacheManager):
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            dtype=dtype,
            num_prefetch_blocks=self.num_prefetch_blocks,
        )

    def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
@@ -948,6 +955,10 @@ class HybridKVCacheManager(KVCacheManager):
            block = self.logical_blocks[logical_id]
            if block.location == BlockLocation.CPU:
                cpu_blocks.append(block.cpu_block_id)
        logger.debug(
            f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
            f"returned cpu_blocks={cpu_blocks}"
        )
        return cpu_blocks

    def load_prev_kv_for_layer(
@@ -1139,28 +1150,18 @@ class HybridKVCacheManager(KVCacheManager):

    def get_write_slot_for_pingpong(self, seq: Sequence) -> int:
        """
-        Get the GPU slot that new KV is written to during Ping-Pong decode.
+        Get the GPU slot that new KV is written to during three-region decode.

-        Strategy: use the number of chunks the sequence needs to decide whether the last
-        chunk used the Ping or the Pong buffer, then use that buffer's last slot.
+        In the three-region design, new KV is always written to the Decode region (slot 0).
+        This avoids conflicts with loads into the Compute/Prefetch regions.

        Args:
            seq: The sequence

        Returns:
-            GPU slot ID
+            GPU slot ID (always decode_slot = 0)
        """
-        cpu_blocks, _ = self.get_all_cpu_blocks(seq)
-        ping_size = self.offload_engine.ping_size
-        num_chunks = (len(cpu_blocks) + ping_size - 1) // ping_size if cpu_blocks else 0
-
-        # Which buffer did the last chunk use?
-        if num_chunks % 2 == 1 or num_chunks == 0:
-            # Odd number of chunks (or zero): the last one used ping
-            return self.offload_engine.ping_slots[-1]
-        else:
-            # Even number of chunks: the last one used pong
-            return self.offload_engine.pong_slots[-1]
+        return self.offload_engine.decode_slot

    def __repr__(self) -> str:
        return (
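For illustration, a plausible decode-step flow around this method in three-region mode; the caller code below is an assumed sketch (not part of this commit), and `manager`, `seq`, and `cpu_block_id` are placeholders. It uses offload_decode_slot and wait_decode_offload, which are added later in this commit.

    slot = manager.get_write_slot_for_pingpong(seq)             # always offload_engine.decode_slot == 0
    # ... attention kernel appends the new token's K/V into GPU slot `slot` ...
    manager.offload_engine.offload_decode_slot(cpu_block_id)    # persist the decode block to CPU
    manager.offload_engine.wait_decode_offload()                # before the decode slot is reused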
@@ -53,6 +53,7 @@ class OffloadEngine:
        head_dim: int,
        dtype: torch.dtype = torch.float16,
        num_streams: int = 4,
        num_prefetch_blocks: int = 2,
    ):
        self.num_layers = num_layers
        self.num_gpu_blocks = num_gpu_blocks
@@ -64,31 +65,59 @@ class OffloadEngine:
        self.kv_dim = num_kv_heads * head_dim
        self.block_numel = block_size * self.kv_dim

-        # ========== Ping-Pong double-buffer configuration ==========
-        assert num_gpu_blocks >= 2, "Ping-Pong requires at least 2 GPU blocks"
-        self.ping_size = num_gpu_blocks // 2
-        self.pong_size = num_gpu_blocks - self.ping_size
-        self.ping_slots = list(range(self.ping_size))  # [0, 1, 2, ...]
-        self.pong_slots = list(range(self.ping_size, num_gpu_blocks))  # [ping_size, ping_size+1, ...]
+        # ========== Three-region GPU buffer configuration ==========
+        # Constraint checks
+        assert num_gpu_blocks >= 3, \
+            f"At least 3 GPU blocks are required: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
+        assert num_prefetch_blocks >= 1, \
+            f"At least 1 prefetch block is required, got {num_prefetch_blocks}"
+        assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \
+            f"Not enough GPU blocks: need 1 (decode) + 1 (compute) + {num_prefetch_blocks} (prefetch), got {num_gpu_blocks}"
+
+        # Three-region layout
+        # Decode region: [0] - a single fixed block for writing new KV
+        self.decode_slot = 0
+
+        # Compute region: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
+        compute_start = 1
+        compute_end = num_gpu_blocks - num_prefetch_blocks
+        self.compute_slots = list(range(compute_start, compute_end))
+        self.num_compute_blocks = len(self.compute_slots)
+
+        # Prefetch region: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
+        prefetch_start = compute_end
+        self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks))
+        self.num_prefetch_blocks = num_prefetch_blocks
+
+        self.num_gpu_slots = num_gpu_blocks  # alias
+
+        # Keep the old ping/pong attributes for compatibility (to be removed later)
+        self.ping_size = self.num_compute_blocks
+        self.pong_size = self.num_prefetch_blocks
+        self.ping_slots = self.compute_slots.copy()
+        self.pong_slots = self.prefetch_slots.copy()
+
+        logger.info(f"Three-region GPU buffer: decode_slot={self.decode_slot}, "
+                    f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")

        # ========== Fixed-address GPU KV cache ==========
        # Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
-        self.k_cache_gpu = torch.empty(
+        # Zero-initialize to avoid reading uninitialized memory
+        self.k_cache_gpu = torch.zeros(
            num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
            dtype=dtype, device="cuda"
        )
-        self.v_cache_gpu = torch.empty(
+        self.v_cache_gpu = torch.zeros(
            num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
            dtype=dtype, device="cuda"
        )

        # ========== Fixed-address CPU KV cache (pinned memory) ==========
-        self.k_cache_cpu = torch.empty(
+        self.k_cache_cpu = torch.zeros(
            num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
            dtype=dtype, device="cpu", pin_memory=True
        )
-        self.v_cache_cpu = torch.empty(
+        self.v_cache_cpu = torch.zeros(
            num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
            dtype=dtype, device="cpu", pin_memory=True
        )
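A worked example of the slot partition defined above, for an assumed configuration of num_gpu_blocks=6 and num_prefetch_blocks=2 (any sizes satisfying the asserts partition the same way):

    # num_gpu_blocks=6, num_prefetch_blocks=2  (example values)
    # decode_slot    = 0
    # compute_slots  = list(range(1, 6 - 2)) = [1, 2, 3]   -> num_compute_blocks = 3
    # prefetch_slots = list(range(4, 6))     = [4, 5]
    # legacy aliases: ping_slots = [1, 2, 3], pong_slots = [4, 5]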
@@ -111,14 +140,18 @@ class OffloadEngine:
        self.compute_stream = torch.cuda.current_stream()
        self._stream_idx = 0

-        # ========== Dedicated Ping-Pong stream and events ==========
-        self.pingpong_stream = torch.cuda.Stream()  # dedicated to Ping-Pong transfers
+        # ========== Dedicated three-region stream and events ==========
+        self.transfer_stream_main = torch.cuda.Stream()  # main transfer stream

-        # Sync events - loads complete
-        self.ping_ready = torch.cuda.Event()
-        self.pong_ready = torch.cuda.Event()
+        # Sync events - three-region loads complete
+        self.compute_ready = torch.cuda.Event()
+        self.prefetch_ready = torch.cuda.Event()
+        self.decode_offload_done = torch.cuda.Event()

-        # Sync events - offload complete
+        # Keep the old ping/pong events for compatibility (to be removed later)
+        self.pingpong_stream = self.transfer_stream_main
+        self.ping_ready = self.compute_ready
+        self.pong_ready = self.prefetch_ready
        self.ping_offload_done = torch.cuda.Event()
        self.pong_offload_done = torch.cuda.Event()
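These events follow the usual record-on-transfer-stream, wait-on-compute-stream pattern. A minimal sketch of the intended sequencing is below; it is assumed usage for illustration (`engine` is a placeholder), and the actual copies and waits are implemented by load_to_compute and wait_compute later in this commit.

    with torch.cuda.stream(engine.transfer_stream_main):
        ...  # async H2D copies into the compute slots
        engine.compute_ready.record(engine.transfer_stream_main)
    # later, on the compute stream, before attention reads those slots:
    engine.wait_compute()  # compute_stream.wait_event(compute_ready)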
@@ -535,7 +568,7 @@ class OffloadEngine:
            f"  kv_heads={self.num_kv_heads},\n"
            f"  head_dim={self.head_dim},\n"
            f"  dtype={self.dtype},\n"
-            f"  ping_size={self.ping_size}, pong_size={self.pong_size},\n"
+            f"  three-region: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
            f"  gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
            f"  cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
            f")"
@@ -742,4 +775,189 @@ class OffloadEngine:
        v = self.v_cache_gpu[layer_id, gpu_slots]
        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
        return k, v

    # ========== Three-region GPU buffer methods ==========

    def load_to_compute(self, cpu_block_ids: List[int]) -> None:
        """
        Asynchronously load CPU blocks into the Compute region.

        Args:
            cpu_block_ids: List of CPU block IDs to load
        """
        if not cpu_block_ids:
            self.compute_ready.record(self.transfer_stream_main)
            return

        num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
        logger.debug(f"Compute load: CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")

        with torch.cuda.stream(self.transfer_stream_main):
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = self.compute_slots[i]
                # Copy all layers at once
                self.k_cache_gpu[:, gpu_slot].copy_(
                    self.k_cache_cpu[:, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[:, gpu_slot].copy_(
                    self.v_cache_cpu[:, cpu_id], non_blocking=True
                )
            self.compute_ready.record(self.transfer_stream_main)

    def load_to_prefetch(self, cpu_block_ids: List[int]) -> None:
        """
        Asynchronously load CPU blocks into the Prefetch region.

        Args:
            cpu_block_ids: List of CPU block IDs to load
        """
        if not cpu_block_ids:
            self.prefetch_ready.record(self.transfer_stream_main)
            return

        num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
        logger.debug(f"Prefetch load: CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")

        with torch.cuda.stream(self.transfer_stream_main):
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = self.prefetch_slots[i]
                self.k_cache_gpu[:, gpu_slot].copy_(
                    self.k_cache_cpu[:, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[:, gpu_slot].copy_(
                    self.v_cache_cpu[:, cpu_id], non_blocking=True
                )
            self.prefetch_ready.record(self.transfer_stream_main)

    def wait_compute(self) -> None:
        """Wait for the Compute-region load to finish."""
        self.compute_stream.wait_event(self.compute_ready)

    def wait_prefetch(self) -> None:
        """Wait for the Prefetch-region load to finish."""
        self.compute_stream.wait_event(self.prefetch_ready)

    def swap_compute_prefetch(self) -> None:
        """Swap the roles of the Compute and Prefetch regions."""
        self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
        # Also update the legacy ping/pong slots to stay compatible
        self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots

    def offload_decode_slot(self, cpu_block_id: int) -> None:
        """
        Offload the Decode region's KV to CPU.

        Args:
            cpu_block_id: Target CPU block ID
        """
        logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]")

        with torch.cuda.stream(self.transfer_stream_main):
            self.transfer_stream_main.wait_stream(self.compute_stream)
            self.k_cache_cpu[:, cpu_block_id].copy_(
                self.k_cache_gpu[:, self.decode_slot], non_blocking=True
            )
            self.v_cache_cpu[:, cpu_block_id].copy_(
                self.v_cache_gpu[:, self.decode_slot], non_blocking=True
            )
            self.decode_offload_done.record(self.transfer_stream_main)

    def wait_decode_offload(self) -> None:
        """Wait for the Decode-region offload to finish."""
        self.compute_stream.wait_event(self.decode_offload_done)

    def get_kv_for_compute(
        self,
        layer_id: int,
        num_blocks: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get the KV for the given number of blocks in the Compute region.

        Args:
            layer_id: Layer ID
            num_blocks: Number of blocks needed

        Returns:
            (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
        """
        slots = self.compute_slots[:num_blocks]
        k = self.k_cache_gpu[layer_id, slots]  # [num_blocks, block_size, heads, dim]
        v = self.v_cache_gpu[layer_id, slots]
        # Reshape: [num_blocks, block_size, heads, dim] -> [1, num_blocks*block_size, heads, dim]
        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
        return k, v

    def get_kv_for_prefetch(
        self,
        layer_id: int,
        num_blocks: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get the KV for the given number of blocks in the Prefetch region.

        Args:
            layer_id: Layer ID
            num_blocks: Number of blocks needed

        Returns:
            (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
        """
        slots = self.prefetch_slots[:num_blocks]
        k = self.k_cache_gpu[layer_id, slots]
        v = self.v_cache_gpu[layer_id, slots]
        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
        return k, v

    def get_kv_for_decode_slot(
        self,
        layer_id: int,
        pos_in_block: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get the KV at a given position in the Decode region (the new token written during decode).

        Args:
            layer_id: Layer ID
            pos_in_block: Position of the token within the block (0 to block_size-1)

        Returns:
            (k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
        """
        k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]  # [1, heads, dim]
        v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
        k = k.unsqueeze(0)  # [1, 1, heads, dim]
        v = v.unsqueeze(0)
        return k, v

    def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
        """
        Offload the Compute region's KV to CPU.

        Args:
            cpu_block_ids: Target CPU block IDs
        """
        if not cpu_block_ids:
            return

        num_to_offload = min(len(cpu_block_ids), len(self.compute_slots))
        logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")

        with torch.cuda.stream(self.transfer_stream_main):
            # Wait for compute to finish
            self.transfer_stream_main.wait_stream(self.compute_stream)

            for i in range(num_to_offload):
                gpu_slot = self.compute_slots[i]
                cpu_id = cpu_block_ids[i]
                self.k_cache_cpu[:, cpu_id].copy_(
                    self.k_cache_gpu[:, gpu_slot], non_blocking=True
                )
                self.v_cache_cpu[:, cpu_id].copy_(
                    self.v_cache_gpu[:, gpu_slot], non_blocking=True
                )
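Taken together, a plausible per-step driver over these new methods could look like the sketch below. This is an assumed usage pattern for illustration only, not code from this commit; `engine`, `chunks` (CPU block IDs grouped to the Compute-region size), `layer_id`, and `cpu_block_id` are placeholders, and the attention call is elided.

    engine.load_to_compute(chunks[0])                     # fill the Compute region
    for i, chunk in enumerate(chunks):
        if i + 1 < len(chunks):
            engine.load_to_prefetch(chunks[i + 1])        # overlap the next chunk's H2D copy
        engine.wait_compute()                             # this chunk's KV is now resident
        k, v = engine.get_kv_for_compute(layer_id, len(chunk))
        # ... attention over (k, v), plus the new token's KV in the decode slot ...
        if i + 1 < len(chunks):
            engine.wait_prefetch()
            engine.swap_compute_prefetch()                # prefetched blocks become the Compute region
    engine.offload_decode_slot(cpu_block_id)              # persist the newly written decode block
    engine.wait_decode_offload()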