[fix] Fixed kvcache offload bugs.
This commit is contained in:
@@ -22,6 +22,7 @@ class Config:
|
||||
offload_policy: str = "lru" # "lru", "fifo", or full class path
|
||||
num_transfer_streams: int = 4 # Number of CUDA streams for async transfers
|
||||
num_gpu_blocks: int = -1 # User-specified GPU blocks count, -1 = auto (use max available)
|
||||
num_prefetch_blocks: int = 2 # Prefetch区的block数量,用于三区域GPU Buffer设计
|
||||
|
||||
# Computed fields for offload (set in __post_init__ or by ModelRunner)
|
||||
num_gpu_kvcache_blocks: int = -1
|
||||
|
||||
@@ -148,19 +148,39 @@ class ModelRunner:
|
||||
dtype=hf_config.torch_dtype,
|
||||
)
|
||||
|
||||
# Log KV cache allocation info
|
||||
# Log KV cache allocation info with detailed per-token breakdown
|
||||
gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
|
||||
cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
|
||||
total_memory_mb = gpu_memory_mb + cpu_memory_mb
|
||||
|
||||
# Calculate per-token KV cache usage
|
||||
# KV per token = 2 (K+V) * num_layers * kv_heads * head_dim * dtype_size
|
||||
dtype_size = 2 if hf_config.torch_dtype in [torch.float16, torch.bfloat16] else 4
|
||||
per_token_kv_bytes = 2 * hf_config.num_hidden_layers * num_kv_heads * head_dim * dtype_size
|
||||
per_token_kv_kb = per_token_kv_bytes / 1024
|
||||
|
||||
logger.info(
|
||||
f"KV Cache per-token: {per_token_kv_kb:.2f}KB "
|
||||
f"(2 * {hf_config.num_hidden_layers}layers * {num_kv_heads}kv_heads * {head_dim}head_dim * {dtype_size}bytes)"
|
||||
)
|
||||
logger.info(
|
||||
f"KV Cache per-block: {block_bytes / (1024**2):.2f}MB "
|
||||
f"({per_token_kv_kb:.2f}KB * {self.block_size}tokens)"
|
||||
)
|
||||
|
||||
if config.enable_cpu_offload:
|
||||
ping_size = config.num_gpu_kvcache_blocks // 2
|
||||
tokens_per_ping = ping_size * self.block_size
|
||||
logger.info(
|
||||
f"KV Cache allocated (Ping-Pong mode): "
|
||||
f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
|
||||
f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), "
|
||||
f"Total={total_memory_mb:.1f}MB, "
|
||||
f"block_size={self.block_size}, "
|
||||
f"ping_size={config.num_gpu_kvcache_blocks // 2}"
|
||||
f"Total={total_memory_mb:.1f}MB"
|
||||
)
|
||||
logger.info(
|
||||
f"Ping-Pong config: ping_size={ping_size} blocks, "
|
||||
f"tokens_per_chunk={tokens_per_ping}, "
|
||||
f"block_size={self.block_size}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
@@ -392,12 +412,12 @@ class ModelRunner:
|
||||
|
||||
def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool:
|
||||
"""
|
||||
Check if Ping-Pong mode should be used.
|
||||
Check if 三区域 mode should be used.
|
||||
|
||||
Use Ping-Pong when:
|
||||
Use 三区域 when:
|
||||
- CPU offload is enabled
|
||||
- There are blocks on CPU (either allocated there or offloaded)
|
||||
- Sequence exceeds GPU capacity
|
||||
- Sequence exceeds GPU Compute区 capacity
|
||||
"""
|
||||
if not hasattr(self.kvcache_manager, 'offload_engine'):
|
||||
return False
|
||||
@@ -409,12 +429,12 @@ class ModelRunner:
|
||||
# Check if any blocks are on CPU
|
||||
cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq)
|
||||
if cpu_blocks:
|
||||
# Has CPU blocks - use Ping-Pong
|
||||
# Has CPU blocks - use 三区域
|
||||
return True
|
||||
|
||||
# Check if sequence needs more blocks than GPU can hold
|
||||
ping_size = self.kvcache_manager.offload_engine.ping_size
|
||||
if seq.num_blocks > ping_size:
|
||||
# Check if sequence needs more blocks than GPU Compute区 can hold
|
||||
compute_size = self.kvcache_manager.offload_engine.num_compute_blocks
|
||||
if seq.num_blocks > compute_size:
|
||||
# Needs chunked processing
|
||||
return True
|
||||
|
||||
@@ -610,29 +630,28 @@ class ModelRunner:
|
||||
|
||||
def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]:
|
||||
"""
|
||||
Run prefill with Ping-Pong dual buffer (CPU is primary storage).
|
||||
Run prefill with 三区域 GPU buffer (CPU is primary storage).
|
||||
|
||||
Flow:
|
||||
1. All blocks are allocated to CPU (primary storage)
|
||||
2. Process tokens in chunks using Ping/Pong GPU buffers
|
||||
3. After each chunk, offload from GPU to CPU
|
||||
4. Alternate between Ping and Pong buffers
|
||||
2. Process tokens in chunks using Compute区 GPU buffer
|
||||
3. After each chunk, offload from Compute区 to CPU
|
||||
4. Prefetch区 用于加载 previous KV(如果有的话)
|
||||
"""
|
||||
import sys
|
||||
|
||||
assert len(seqs) == 1, "Ping-Pong prefill only supports single sequence"
|
||||
assert len(seqs) == 1, "三区域 prefill only supports single sequence"
|
||||
seq = seqs[0]
|
||||
|
||||
offload_engine = self.kvcache_manager.offload_engine
|
||||
ping_size = offload_engine.ping_size
|
||||
tokens_per_chunk = ping_size * self.block_size
|
||||
compute_size = offload_engine.num_compute_blocks
|
||||
tokens_per_chunk = compute_size * self.block_size
|
||||
|
||||
total_tokens = len(seq)
|
||||
print(f"[Ping-Pong Prefill] Starting: {total_tokens} tokens, "
|
||||
f"ping_size={ping_size} blocks, chunk={tokens_per_chunk} tokens",
|
||||
print(f"[三区域 Prefill] Starting: {total_tokens} tokens, "
|
||||
f"compute_size={compute_size} blocks, chunk={tokens_per_chunk} tokens",
|
||||
file=sys.stderr)
|
||||
|
||||
current_buffer = "ping"
|
||||
chunk_num = 0
|
||||
logits = None
|
||||
processed_tokens = 0
|
||||
@@ -651,15 +670,13 @@ class ModelRunner:
|
||||
end_block_idx = (chunk_end + self.block_size - 1) // self.block_size
|
||||
num_blocks = end_block_idx - start_block_idx
|
||||
|
||||
print(f"[Ping-Pong Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
|
||||
f"blocks {start_block_idx}-{end_block_idx-1}, buffer={current_buffer}",
|
||||
print(f"[三区域 Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
|
||||
f"blocks {start_block_idx}-{end_block_idx-1}, "
|
||||
f"compute_slots={offload_engine.compute_slots[:num_blocks]}",
|
||||
file=sys.stderr)
|
||||
|
||||
# Get GPU slots for this chunk (Ping or Pong buffer)
|
||||
if current_buffer == "ping":
|
||||
gpu_slots = offload_engine.ping_slots[:num_blocks]
|
||||
else:
|
||||
gpu_slots = offload_engine.pong_slots[:num_blocks]
|
||||
# Get GPU slots for this chunk (使用 Compute区)
|
||||
gpu_slots = offload_engine.compute_slots[:num_blocks]
|
||||
|
||||
# Prepare inputs
|
||||
input_ids, positions = self._prepare_pingpong_chunk(
|
||||
@@ -678,24 +695,19 @@ class ModelRunner:
|
||||
logical_id = seq.block_table[i]
|
||||
self.kvcache_manager.prefilled_blocks.add(logical_id)
|
||||
|
||||
# Offload this chunk from GPU to CPU (async)
|
||||
# Offload this chunk from Compute区 to CPU (async)
|
||||
chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx]
|
||||
offload_engine.offload_buffer_to_cpu(current_buffer, chunk_cpu_blocks)
|
||||
offload_engine.offload_compute_to_cpu(chunk_cpu_blocks)
|
||||
|
||||
# Switch buffer for next chunk
|
||||
if current_buffer == "ping":
|
||||
offload_engine.wait_ping_offload_done()
|
||||
current_buffer = "pong"
|
||||
else:
|
||||
offload_engine.wait_pong_offload_done()
|
||||
current_buffer = "ping"
|
||||
# Wait for offload to complete before next chunk
|
||||
offload_engine.wait_all_offload_done()
|
||||
|
||||
processed_tokens = chunk_end
|
||||
|
||||
# Wait for all offloads to complete
|
||||
offload_engine.wait_all_offload_done()
|
||||
|
||||
print(f"[Ping-Pong Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
|
||||
print(f"[三区域 Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
|
||||
|
||||
# Sample from last logits
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
@@ -764,25 +776,26 @@ class ModelRunner:
|
||||
|
||||
def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]:
|
||||
"""
|
||||
Run decode with Ping-Pong dual buffer.
|
||||
Run decode with 三区域 GPU buffer.
|
||||
|
||||
All KV is on CPU. Uses Ping-Pong to load KV chunks and compute attention.
|
||||
New token's KV is written to GPU then offloaded to CPU.
|
||||
All KV is on CPU. Uses Decode区 to write new KV, Compute/Prefetch区 to load KV chunks.
|
||||
New token's KV is written to Decode区 (slot 0) then offloaded to CPU.
|
||||
|
||||
关键:Decode区 永远不会被 Compute/Prefetch 覆盖,专门用于写入新KV。
|
||||
"""
|
||||
assert len(seqs) == 1, "Ping-Pong decode only supports single sequence"
|
||||
assert len(seqs) == 1, "三区域 decode only supports single sequence"
|
||||
seq = seqs[0]
|
||||
|
||||
offload_engine = self.kvcache_manager.offload_engine
|
||||
|
||||
# Prepare inputs
|
||||
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
|
||||
# Get write slot for new KV (will use last slot of the buffer used for final chunk)
|
||||
write_slot = self.kvcache_manager.get_write_slot_for_pingpong(seq)
|
||||
|
||||
# Calculate position in block for slot mapping
|
||||
last_block_idx = seq.num_blocks - 1
|
||||
# 使用 Decode区 (slot 0) 写入新 KV
|
||||
decode_slot = offload_engine.decode_slot # = 0
|
||||
pos_in_block = (len(seq) - 1) % self.block_size
|
||||
slot = write_slot * self.block_size + pos_in_block
|
||||
slot = decode_slot * self.block_size + pos_in_block
|
||||
slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
|
||||
@@ -794,17 +807,18 @@ class ModelRunner:
|
||||
is_chunked_prefill=True, # Use chunked attention path
|
||||
offload_engine=self.kvcache_manager,
|
||||
chunked_seq=seq,
|
||||
decode_pos_in_block=pos_in_block,
|
||||
)
|
||||
|
||||
# Run model forward pass
|
||||
logits = self.run_model(input_ids, positions, is_prefill=False)
|
||||
reset_context()
|
||||
|
||||
# Offload new KV from write_slot to CPU
|
||||
# Offload new KV from Decode区 to CPU
|
||||
last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
|
||||
if last_cpu_block >= 0:
|
||||
self.kvcache_manager.offload_engine.offload_slot_to_cpu(write_slot, last_cpu_block)
|
||||
self.kvcache_manager.offload_engine.wait_all_offload_done()
|
||||
offload_engine.offload_decode_slot(last_cpu_block)
|
||||
offload_engine.wait_all_offload_done()
|
||||
|
||||
# Sample
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
|
||||
@@ -58,12 +58,14 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
||||
from nanovllm.kvcache.policies import get_policy
|
||||
|
||||
policy = get_policy(getattr(config, 'offload_policy', 'lru'))
|
||||
num_prefetch_blocks = getattr(config, 'num_prefetch_blocks', 2)
|
||||
|
||||
return HybridKVCacheManager(
|
||||
num_gpu_slots=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
block_size=config.kvcache_block_size,
|
||||
policy=policy,
|
||||
num_prefetch_blocks=num_prefetch_blocks,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ Key design for CUDA Graph compatibility:
|
||||
5. Graph replay only needs index updates (tiny overhead)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
@@ -16,6 +17,8 @@ from typing import List, Tuple, Dict, Set, Optional
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
from nanovllm.kvcache.base_manager import KVCacheManager
|
||||
from nanovllm.kvcache.offload_engine import OffloadEngine
|
||||
@@ -82,6 +85,7 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
block_size: int,
|
||||
policy: Optional[EvictionPolicy] = None,
|
||||
cpu_primary: bool = True,
|
||||
num_prefetch_blocks: int = 2,
|
||||
):
|
||||
"""
|
||||
Initialize hybrid manager.
|
||||
@@ -91,14 +95,16 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
|
||||
block_size: Tokens per block
|
||||
policy: Eviction policy (default: LRU)
|
||||
cpu_primary: If True, use CPU as primary storage with Ping-Pong GPU buffer.
|
||||
cpu_primary: If True, use CPU as primary storage with 三区域 GPU buffer.
|
||||
If False, use GPU as primary with CPU as overflow (legacy mode).
|
||||
num_prefetch_blocks: Number of prefetch blocks for 三区域 GPU buffer design
|
||||
"""
|
||||
self._block_size = block_size
|
||||
self.num_gpu_slots = num_gpu_slots
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
self.total_blocks = num_gpu_slots + num_cpu_blocks
|
||||
self.cpu_primary = cpu_primary # Ping-Pong mode flag
|
||||
self.cpu_primary = cpu_primary # 三区域 mode flag
|
||||
self.num_prefetch_blocks = num_prefetch_blocks # 三区域设计参数
|
||||
|
||||
# Eviction policy
|
||||
self.policy = policy or LRUPolicy()
|
||||
@@ -156,6 +162,7 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
dtype=dtype,
|
||||
num_prefetch_blocks=self.num_prefetch_blocks,
|
||||
)
|
||||
|
||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
@@ -948,6 +955,10 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.location == BlockLocation.CPU:
|
||||
cpu_blocks.append(block.cpu_block_id)
|
||||
logger.debug(
|
||||
f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
|
||||
f"returned cpu_blocks={cpu_blocks}"
|
||||
)
|
||||
return cpu_blocks
|
||||
|
||||
def load_prev_kv_for_layer(
|
||||
@@ -1139,28 +1150,18 @@ class HybridKVCacheManager(KVCacheManager):
|
||||
|
||||
def get_write_slot_for_pingpong(self, seq: Sequence) -> int:
|
||||
"""
|
||||
获取 Ping-Pong decode 时新 KV 写入的 GPU slot。
|
||||
获取三区域 decode 时新 KV 写入的 GPU slot。
|
||||
|
||||
策略:使用序列所需 chunks 数决定最后用的是 Ping 还是 Pong buffer,
|
||||
然后使用该 buffer 的最后一个 slot。
|
||||
在三区域设计中,永远使用 Decode区 (slot 0) 写入新 KV。
|
||||
这样可以避免与 Compute/Prefetch区 的加载操作冲突。
|
||||
|
||||
Args:
|
||||
seq: 序列
|
||||
|
||||
Returns:
|
||||
GPU slot ID
|
||||
GPU slot ID (永远是 decode_slot = 0)
|
||||
"""
|
||||
cpu_blocks, _ = self.get_all_cpu_blocks(seq)
|
||||
ping_size = self.offload_engine.ping_size
|
||||
num_chunks = (len(cpu_blocks) + ping_size - 1) // ping_size if cpu_blocks else 0
|
||||
|
||||
# 最后一个 chunk 用的是哪个 buffer
|
||||
if num_chunks % 2 == 1 or num_chunks == 0:
|
||||
# 奇数个 chunk(或0个),最后用的是 ping
|
||||
return self.offload_engine.ping_slots[-1]
|
||||
else:
|
||||
# 偶数个 chunk,最后用的是 pong
|
||||
return self.offload_engine.pong_slots[-1]
|
||||
return self.offload_engine.decode_slot
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
|
||||
@@ -53,6 +53,7 @@ class OffloadEngine:
|
||||
head_dim: int,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
num_streams: int = 4,
|
||||
num_prefetch_blocks: int = 2,
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
@@ -64,31 +65,59 @@ class OffloadEngine:
|
||||
self.kv_dim = num_kv_heads * head_dim
|
||||
self.block_numel = block_size * self.kv_dim
|
||||
|
||||
# ========== Ping-Pong 双缓冲配置 ==========
|
||||
assert num_gpu_blocks >= 2, "Ping-Pong需要至少2个GPU blocks"
|
||||
self.ping_size = num_gpu_blocks // 2
|
||||
self.pong_size = num_gpu_blocks - self.ping_size
|
||||
self.ping_slots = list(range(self.ping_size)) # [0, 1, 2, ...]
|
||||
self.pong_slots = list(range(self.ping_size, num_gpu_blocks)) # [ping_size, ping_size+1, ...]
|
||||
# ========== 三区域 GPU Buffer 配置 ==========
|
||||
# 约束检查
|
||||
assert num_gpu_blocks >= 3, \
|
||||
f"至少需要3个GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
|
||||
assert num_prefetch_blocks >= 1, \
|
||||
f"至少需要1个prefetch block, got {num_prefetch_blocks}"
|
||||
assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \
|
||||
f"GPU blocks不足: 需要 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}"
|
||||
|
||||
# 三区域配置
|
||||
# Decode区: [0] - 固定1个block用于写入新KV
|
||||
self.decode_slot = 0
|
||||
|
||||
# Compute区: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
|
||||
compute_start = 1
|
||||
compute_end = num_gpu_blocks - num_prefetch_blocks
|
||||
self.compute_slots = list(range(compute_start, compute_end))
|
||||
self.num_compute_blocks = len(self.compute_slots)
|
||||
|
||||
# Prefetch区: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
|
||||
prefetch_start = compute_end
|
||||
self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks))
|
||||
self.num_prefetch_blocks = num_prefetch_blocks
|
||||
|
||||
self.num_gpu_slots = num_gpu_blocks # alias
|
||||
|
||||
# 保留旧的ping/pong属性以兼容(后续会移除)
|
||||
self.ping_size = self.num_compute_blocks
|
||||
self.pong_size = self.num_prefetch_blocks
|
||||
self.ping_slots = self.compute_slots.copy()
|
||||
self.pong_slots = self.prefetch_slots.copy()
|
||||
|
||||
logger.info(f"三区域 GPU Buffer: decode_slot={self.decode_slot}, "
|
||||
f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")
|
||||
|
||||
# ========== Fixed-address GPU KV cache ==========
|
||||
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
|
||||
self.k_cache_gpu = torch.empty(
|
||||
# 使用 zeros 初始化以避免未初始化内存问题
|
||||
self.k_cache_gpu = torch.zeros(
|
||||
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cuda"
|
||||
)
|
||||
self.v_cache_gpu = torch.empty(
|
||||
self.v_cache_gpu = torch.zeros(
|
||||
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
# ========== Fixed-address CPU KV cache (pinned memory) ==========
|
||||
self.k_cache_cpu = torch.empty(
|
||||
self.k_cache_cpu = torch.zeros(
|
||||
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cpu", pin_memory=True
|
||||
)
|
||||
self.v_cache_cpu = torch.empty(
|
||||
self.v_cache_cpu = torch.zeros(
|
||||
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cpu", pin_memory=True
|
||||
)
|
||||
@@ -111,14 +140,18 @@ class OffloadEngine:
|
||||
self.compute_stream = torch.cuda.current_stream()
|
||||
self._stream_idx = 0
|
||||
|
||||
# ========== Ping-Pong 专用 stream 和事件 ==========
|
||||
self.pingpong_stream = torch.cuda.Stream() # 专用于Ping-Pong传输
|
||||
# ========== 三区域专用 stream 和事件 ==========
|
||||
self.transfer_stream_main = torch.cuda.Stream() # 主传输stream
|
||||
|
||||
# 同步事件 - 加载完成
|
||||
self.ping_ready = torch.cuda.Event()
|
||||
self.pong_ready = torch.cuda.Event()
|
||||
# 同步事件 - 三区域加载完成
|
||||
self.compute_ready = torch.cuda.Event()
|
||||
self.prefetch_ready = torch.cuda.Event()
|
||||
self.decode_offload_done = torch.cuda.Event()
|
||||
|
||||
# 同步事件 - offload完成
|
||||
# 保留旧的ping/pong事件以兼容(后续会移除)
|
||||
self.pingpong_stream = self.transfer_stream_main
|
||||
self.ping_ready = self.compute_ready
|
||||
self.pong_ready = self.prefetch_ready
|
||||
self.ping_offload_done = torch.cuda.Event()
|
||||
self.pong_offload_done = torch.cuda.Event()
|
||||
|
||||
@@ -535,7 +568,7 @@ class OffloadEngine:
|
||||
f" kv_heads={self.num_kv_heads},\n"
|
||||
f" head_dim={self.head_dim},\n"
|
||||
f" dtype={self.dtype},\n"
|
||||
f" ping_size={self.ping_size}, pong_size={self.pong_size},\n"
|
||||
f" 三区域: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
|
||||
f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
|
||||
f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
|
||||
f")"
|
||||
@@ -742,4 +775,189 @@ class OffloadEngine:
|
||||
v = self.v_cache_gpu[layer_id, gpu_slots]
|
||||
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
|
||||
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
|
||||
return k, v
|
||||
return k, v
|
||||
|
||||
# ========== 三区域 GPU Buffer 方法 ==========
|
||||
|
||||
def load_to_compute(self, cpu_block_ids: List[int]) -> None:
|
||||
"""
|
||||
异步加载CPU blocks到Compute区。
|
||||
|
||||
Args:
|
||||
cpu_block_ids: 要加载的CPU block IDs列表
|
||||
"""
|
||||
if not cpu_block_ids:
|
||||
self.compute_ready.record(self.transfer_stream_main)
|
||||
return
|
||||
|
||||
num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
|
||||
logger.debug(f"Compute load: CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")
|
||||
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
for i in range(num_to_load):
|
||||
cpu_id = cpu_block_ids[i]
|
||||
gpu_slot = self.compute_slots[i]
|
||||
# 所有层一起复制
|
||||
self.k_cache_gpu[:, gpu_slot].copy_(
|
||||
self.k_cache_cpu[:, cpu_id], non_blocking=True
|
||||
)
|
||||
self.v_cache_gpu[:, gpu_slot].copy_(
|
||||
self.v_cache_cpu[:, cpu_id], non_blocking=True
|
||||
)
|
||||
self.compute_ready.record(self.transfer_stream_main)
|
||||
|
||||
def load_to_prefetch(self, cpu_block_ids: List[int]) -> None:
|
||||
"""
|
||||
异步加载CPU blocks到Prefetch区。
|
||||
|
||||
Args:
|
||||
cpu_block_ids: 要加载的CPU block IDs列表
|
||||
"""
|
||||
if not cpu_block_ids:
|
||||
self.prefetch_ready.record(self.transfer_stream_main)
|
||||
return
|
||||
|
||||
num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
|
||||
logger.debug(f"Prefetch load: CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")
|
||||
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
for i in range(num_to_load):
|
||||
cpu_id = cpu_block_ids[i]
|
||||
gpu_slot = self.prefetch_slots[i]
|
||||
self.k_cache_gpu[:, gpu_slot].copy_(
|
||||
self.k_cache_cpu[:, cpu_id], non_blocking=True
|
||||
)
|
||||
self.v_cache_gpu[:, gpu_slot].copy_(
|
||||
self.v_cache_cpu[:, cpu_id], non_blocking=True
|
||||
)
|
||||
self.prefetch_ready.record(self.transfer_stream_main)
|
||||
|
||||
def wait_compute(self) -> None:
|
||||
"""等待Compute区加载完成。"""
|
||||
self.compute_stream.wait_event(self.compute_ready)
|
||||
|
||||
def wait_prefetch(self) -> None:
|
||||
"""等待Prefetch区加载完成。"""
|
||||
self.compute_stream.wait_event(self.prefetch_ready)
|
||||
|
||||
def swap_compute_prefetch(self) -> None:
|
||||
"""交换Compute区和Prefetch区的角色。"""
|
||||
self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
|
||||
# 同时更新旧的ping/pong slots以保持兼容
|
||||
self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots
|
||||
|
||||
def offload_decode_slot(self, cpu_block_id: int) -> None:
|
||||
"""
|
||||
将Decode区的KV offload到CPU。
|
||||
|
||||
Args:
|
||||
cpu_block_id: 目标CPU block ID
|
||||
"""
|
||||
logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]")
|
||||
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
self.transfer_stream_main.wait_stream(self.compute_stream)
|
||||
self.k_cache_cpu[:, cpu_block_id].copy_(
|
||||
self.k_cache_gpu[:, self.decode_slot], non_blocking=True
|
||||
)
|
||||
self.v_cache_cpu[:, cpu_block_id].copy_(
|
||||
self.v_cache_gpu[:, self.decode_slot], non_blocking=True
|
||||
)
|
||||
self.decode_offload_done.record(self.transfer_stream_main)
|
||||
|
||||
def wait_decode_offload(self) -> None:
|
||||
"""等待Decode区offload完成。"""
|
||||
self.compute_stream.wait_event(self.decode_offload_done)
|
||||
|
||||
def get_kv_for_compute(
|
||||
self,
|
||||
layer_id: int,
|
||||
num_blocks: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
获取Compute区中指定数量blocks的KV。
|
||||
|
||||
Args:
|
||||
layer_id: 层ID
|
||||
num_blocks: 需要的block数量
|
||||
|
||||
Returns:
|
||||
(k_cache, v_cache),shape: [1, num_blocks * block_size, kv_heads, head_dim]
|
||||
"""
|
||||
slots = self.compute_slots[:num_blocks]
|
||||
k = self.k_cache_gpu[layer_id, slots] # [num_blocks, block_size, heads, dim]
|
||||
v = self.v_cache_gpu[layer_id, slots]
|
||||
# Reshape: [num_blocks, block_size, heads, dim] -> [1, num_blocks*block_size, heads, dim]
|
||||
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
|
||||
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
|
||||
return k, v
|
||||
|
||||
def get_kv_for_prefetch(
|
||||
self,
|
||||
layer_id: int,
|
||||
num_blocks: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
获取Prefetch区中指定数量blocks的KV。
|
||||
|
||||
Args:
|
||||
layer_id: 层ID
|
||||
num_blocks: 需要的block数量
|
||||
|
||||
Returns:
|
||||
(k_cache, v_cache),shape: [1, num_blocks * block_size, kv_heads, head_dim]
|
||||
"""
|
||||
slots = self.prefetch_slots[:num_blocks]
|
||||
k = self.k_cache_gpu[layer_id, slots]
|
||||
v = self.v_cache_gpu[layer_id, slots]
|
||||
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
|
||||
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
|
||||
return k, v
|
||||
|
||||
def get_kv_for_decode_slot(
|
||||
self,
|
||||
layer_id: int,
|
||||
pos_in_block: int,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
获取Decode区指定位置的KV(用于decode时的新token)。
|
||||
|
||||
Args:
|
||||
layer_id: 层ID
|
||||
pos_in_block: token在block内的位置 (0 to block_size-1)
|
||||
|
||||
Returns:
|
||||
(k_cache, v_cache),shape: [1, 1, kv_heads, head_dim]
|
||||
"""
|
||||
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] # [1, heads, dim]
|
||||
v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
|
||||
k = k.unsqueeze(0) # [1, 1, heads, dim]
|
||||
v = v.unsqueeze(0)
|
||||
return k, v
|
||||
|
||||
def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
|
||||
"""
|
||||
将Compute区的KV offload到CPU。
|
||||
|
||||
Args:
|
||||
cpu_block_ids: 目标CPU block IDs列表
|
||||
"""
|
||||
if not cpu_block_ids:
|
||||
return
|
||||
|
||||
num_to_offload = min(len(cpu_block_ids), len(self.compute_slots))
|
||||
logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
|
||||
|
||||
with torch.cuda.stream(self.transfer_stream_main):
|
||||
# 等待计算完成
|
||||
self.transfer_stream_main.wait_stream(self.compute_stream)
|
||||
|
||||
for i in range(num_to_offload):
|
||||
gpu_slot = self.compute_slots[i]
|
||||
cpu_id = cpu_block_ids[i]
|
||||
self.k_cache_cpu[:, cpu_id].copy_(
|
||||
self.k_cache_gpu[:, gpu_slot], non_blocking=True
|
||||
)
|
||||
self.v_cache_cpu[:, cpu_id].copy_(
|
||||
self.v_cache_gpu[:, gpu_slot], non_blocking=True
|
||||
)
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import torch
|
||||
from torch import nn
|
||||
import triton
|
||||
@@ -6,6 +7,8 @@ import triton.language as tl
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
from nanovllm.utils.context import get_context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def store_kvcache_kernel(
|
||||
@@ -97,13 +100,16 @@ class Attention(nn.Module):
|
||||
context,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute attention with Ping-Pong dual buffer for chunked prefill.
|
||||
Compute attention with 三区域 GPU buffer for chunked prefill.
|
||||
|
||||
For chunked prefill:
|
||||
1. Load previous KV from CPU using Ping-Pong (if any previous chunks)
|
||||
1. Load previous KV from CPU using Compute/Prefetch区 (if any previous chunks)
|
||||
2. Compute attention against previous KV chunks (no causal mask)
|
||||
3. Compute attention against current chunk's KV (causal)
|
||||
4. Merge all results using online softmax
|
||||
|
||||
三区域设计保证:当前chunk的KV在Compute区,previous KV从CPU加载到Prefetch区,
|
||||
不会发生写入和加载区域重叠的问题。
|
||||
"""
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
@@ -116,7 +122,7 @@ class Attention(nn.Module):
|
||||
o_acc = None
|
||||
lse_acc = None
|
||||
|
||||
# Load previous KV from CPU using Ping-Pong
|
||||
# Load previous KV from CPU using Compute/Prefetch区
|
||||
# Note: context.offload_engine is actually HybridKVCacheManager
|
||||
kvcache_manager = context.offload_engine
|
||||
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
|
||||
@@ -127,41 +133,42 @@ class Attention(nn.Module):
|
||||
|
||||
if cpu_block_table:
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
ping_size = offload_engine.ping_size
|
||||
num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
|
||||
current_buffer = "ping"
|
||||
# 使用 Prefetch区 来加载 previous KV(不会与当前 Compute区 冲突)
|
||||
prefetch_size = offload_engine.num_prefetch_blocks
|
||||
num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
|
||||
use_compute = True # 交替使用 Compute区 和 Prefetch区
|
||||
|
||||
# Prefetch first chunk to Ping buffer
|
||||
first_chunk_end = min(ping_size, len(cpu_block_table))
|
||||
# 首先将 previous KV 加载到 Prefetch区
|
||||
# Only layer 0 triggers the load (loads ALL layers at once)
|
||||
first_chunk_end = min(prefetch_size, len(cpu_block_table))
|
||||
first_chunk_ids = cpu_block_table[:first_chunk_end]
|
||||
offload_engine.load_to_ping(first_chunk_ids)
|
||||
if self.layer_id == 0:
|
||||
offload_engine.load_to_prefetch(first_chunk_ids)
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
start = chunk_idx * ping_size
|
||||
end = min(start + ping_size, len(cpu_block_table))
|
||||
start = chunk_idx * prefetch_size
|
||||
end = min(start + prefetch_size, len(cpu_block_table))
|
||||
num_blocks_in_chunk = end - start
|
||||
|
||||
# Prefetch next chunk to OTHER buffer
|
||||
if chunk_idx + 1 < num_chunks:
|
||||
# Prefetch next chunk to other buffer (if exists)
|
||||
# Only layer 0 triggers the load
|
||||
if chunk_idx + 1 < num_chunks and self.layer_id == 0:
|
||||
next_start = end
|
||||
next_end = min(next_start + ping_size, len(cpu_block_table))
|
||||
next_end = min(next_start + prefetch_size, len(cpu_block_table))
|
||||
next_chunk_ids = cpu_block_table[next_start:next_end]
|
||||
if current_buffer == "ping":
|
||||
offload_engine.load_to_pong(next_chunk_ids)
|
||||
if use_compute:
|
||||
# 当前在 Prefetch区,下一个加载到 Compute区(如果有空间)
|
||||
# 注意:Compute区 此时已写入当前chunk的KV,不能覆盖
|
||||
# 所以这里我们使用简单的同步策略:等待当前完成后再加载
|
||||
pass # 简化版本:不进行双缓冲,只用 Prefetch区
|
||||
else:
|
||||
offload_engine.load_to_ping(next_chunk_ids)
|
||||
offload_engine.load_to_prefetch(next_chunk_ids)
|
||||
|
||||
# Wait for current buffer and get KV
|
||||
if current_buffer == "ping":
|
||||
offload_engine.wait_ping()
|
||||
prev_k, prev_v = offload_engine.get_kv_for_ping_slots(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
)
|
||||
else:
|
||||
offload_engine.wait_pong()
|
||||
prev_k, prev_v = offload_engine.get_kv_for_pong_slots(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
)
|
||||
# Wait for Prefetch区 and get KV
|
||||
offload_engine.wait_prefetch()
|
||||
prev_k, prev_v = offload_engine.get_kv_for_prefetch(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
)
|
||||
|
||||
# Compute attention against this chunk (no causal mask)
|
||||
prev_o, prev_lse = flash_attn_with_lse(
|
||||
@@ -178,8 +185,12 @@ class Attention(nn.Module):
|
||||
else:
|
||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
||||
|
||||
# Switch buffer
|
||||
current_buffer = "pong" if current_buffer == "ping" else "ping"
|
||||
# Load next chunk to Prefetch区 (if exists)
|
||||
if chunk_idx + 1 < num_chunks and self.layer_id == 0:
|
||||
next_start = end
|
||||
next_end = min(next_start + prefetch_size, len(cpu_block_table))
|
||||
next_chunk_ids = cpu_block_table[next_start:next_end]
|
||||
offload_engine.load_to_prefetch(next_chunk_ids)
|
||||
|
||||
# Compute attention against current chunk's KV (with causal mask)
|
||||
current_o, current_lse = flash_attn_with_lse(
|
||||
@@ -207,13 +218,16 @@ class Attention(nn.Module):
|
||||
context,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute decode attention with Ping-Pong dual buffer.
|
||||
Compute decode attention with 三区域 GPU buffer.
|
||||
|
||||
All KV is stored on CPU. Uses Ping-Pong buffers on GPU:
|
||||
1. Load first chunk to Ping buffer
|
||||
2. While computing on current buffer, prefetch next chunk to other buffer
|
||||
3. Alternate between Ping and Pong buffers
|
||||
4. Merge attention outputs using online softmax (LSE)
|
||||
All KV is stored on CPU. Uses Compute区 buffer on GPU:
|
||||
1. Load chunk to Compute区
|
||||
2. Compute attention
|
||||
3. Repeat for all chunks
|
||||
4. Finally, attend to Decode区 (slot 0) which contains the new token's KV
|
||||
5. Merge all attention outputs using online softmax (LSE)
|
||||
|
||||
关键:新token的KV在Decode区(slot 0),不会被Compute区的加载覆盖。
|
||||
"""
|
||||
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||
|
||||
@@ -227,51 +241,37 @@ class Attention(nn.Module):
|
||||
|
||||
# Get all CPU blocks for this sequence
|
||||
cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
|
||||
if self.layer_id == 0:
|
||||
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
|
||||
if not cpu_block_table:
|
||||
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
|
||||
|
||||
# Get the actual offload_engine for Ping-Pong operations
|
||||
# Get the actual offload_engine for 三区域 operations
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
|
||||
# Calculate chunk info
|
||||
ping_size = offload_engine.ping_size
|
||||
num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
|
||||
# Calculate chunk info using Compute区
|
||||
compute_size = offload_engine.num_compute_blocks
|
||||
num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size
|
||||
|
||||
o_acc = None
|
||||
lse_acc = None
|
||||
current_buffer = "ping"
|
||||
|
||||
# Prefetch first chunk to Ping buffer (loads all layers at once)
|
||||
first_chunk_end = min(ping_size, len(cpu_block_table))
|
||||
first_chunk_ids = cpu_block_table[:first_chunk_end]
|
||||
offload_engine.load_to_ping(first_chunk_ids)
|
||||
|
||||
for chunk_idx in range(num_chunks):
|
||||
start = chunk_idx * ping_size
|
||||
end = min(start + ping_size, len(cpu_block_table))
|
||||
start = chunk_idx * compute_size
|
||||
end = min(start + compute_size, len(cpu_block_table))
|
||||
num_blocks_in_chunk = end - start
|
||||
chunk_ids = cpu_block_table[start:end]
|
||||
|
||||
# Prefetch next chunk to OTHER buffer (overlapped with current computation)
|
||||
if chunk_idx + 1 < num_chunks:
|
||||
next_start = end
|
||||
next_end = min(next_start + ping_size, len(cpu_block_table))
|
||||
next_chunk_ids = cpu_block_table[next_start:next_end]
|
||||
if current_buffer == "ping":
|
||||
offload_engine.load_to_pong(next_chunk_ids)
|
||||
else:
|
||||
offload_engine.load_to_ping(next_chunk_ids)
|
||||
# Load this chunk to Compute区
|
||||
# Only layer 0 triggers the load (loads ALL layers at once)
|
||||
if self.layer_id == 0:
|
||||
offload_engine.load_to_compute(chunk_ids)
|
||||
|
||||
# Wait for current buffer to be ready and get KV
|
||||
if current_buffer == "ping":
|
||||
offload_engine.wait_ping()
|
||||
k_chunk, v_chunk = offload_engine.get_kv_for_ping_slots(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
)
|
||||
else:
|
||||
offload_engine.wait_pong()
|
||||
k_chunk, v_chunk = offload_engine.get_kv_for_pong_slots(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
)
|
||||
# Wait for Compute区 to be ready and get KV
|
||||
offload_engine.wait_compute()
|
||||
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
|
||||
self.layer_id, num_blocks_in_chunk
|
||||
)
|
||||
|
||||
# Compute attention for this chunk
|
||||
o_chunk, lse_chunk = flash_attn_with_lse(
|
||||
@@ -286,8 +286,21 @@ class Attention(nn.Module):
|
||||
else:
|
||||
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
|
||||
|
||||
# Switch buffer for next iteration
|
||||
current_buffer = "pong" if current_buffer == "ping" else "ping"
|
||||
# Now attend to Decode区 (contains the new token's KV)
|
||||
# This is the token being decoded - only 1 token at position pos_in_block
|
||||
pos_in_block = context.decode_pos_in_block
|
||||
decode_k, decode_v = offload_engine.get_kv_for_decode_slot(self.layer_id, pos_in_block)
|
||||
decode_o, decode_lse = flash_attn_with_lse(
|
||||
q_batched, decode_k, decode_v,
|
||||
softmax_scale=self.scale,
|
||||
causal=False,
|
||||
)
|
||||
|
||||
# Merge with accumulated
|
||||
if o_acc is None:
|
||||
o_acc = decode_o
|
||||
else:
|
||||
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
|
||||
|
||||
if o_acc is None:
|
||||
raise RuntimeError("Chunked decode attention failed: no KV available")
|
||||
|
||||
@@ -27,6 +27,8 @@ class Context:
|
||||
prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
|
||||
# Current sequence being processed (for chunked prefill to load KV)
|
||||
chunked_seq: Any = None
|
||||
# Position within block for decode (used for reading from Decode区)
|
||||
decode_pos_in_block: int = 0
|
||||
|
||||
|
||||
_CONTEXT = Context()
|
||||
@@ -50,6 +52,7 @@ def set_context(
|
||||
chunk_offset=0,
|
||||
offload_engine=None,
|
||||
chunked_seq=None,
|
||||
decode_pos_in_block=0,
|
||||
):
|
||||
global _CONTEXT
|
||||
_CONTEXT = Context(
|
||||
@@ -66,6 +69,7 @@ def set_context(
|
||||
chunk_offset=chunk_offset,
|
||||
offload_engine=offload_engine,
|
||||
chunked_seq=chunked_seq,
|
||||
decode_pos_in_block=decode_pos_in_block,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user