[refactor] Refactor the current GPU and CPU block allocation strategy.

This commit is contained in:
Zijie Tian
2025-12-10 21:23:31 +08:00
parent 0a247ccb1b
commit 190df5f70d
7 changed files with 906 additions and 162 deletions

View File

@@ -19,7 +19,6 @@ class Config:
# CPU Offload configuration
enable_cpu_offload: bool = False
cpu_memory_gb: float = 16.0 # CPU memory limit for KV cache
offload_policy: str = "lru" # "lru", "fifo", or full class path
num_transfer_streams: int = 4 # Number of CUDA streams for async transfers
num_gpu_blocks: int = -1 # User-specified GPU blocks count, -1 = auto (use max available)
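The offload-related fields above are meant to be set together; a minimal, hypothetical sketch of enabling the CPU-primary (Ping-Pong) path (field names come from this diff, the model argument and import path are assumptions):
from nanovllm.config import Config  # assumed import path

cfg = Config(
    model="Qwen/Qwen3-0.6B",      # hypothetical model id; other fields keep their defaults
    enable_cpu_offload=True,      # turn on the CPU-primary / Ping-Pong path
    cpu_memory_gb=16.0,           # pinned CPU memory budget for the KV cache
    num_transfer_streams=4,       # CUDA streams for async H2D/D2H copies
    num_gpu_blocks=-1,            # -1 = auto-size the GPU working buffer
)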

View File

@@ -10,8 +10,11 @@ from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import Sampler
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model
from nanovllm.utils.logger import get_logger
from nanovllm.kvcache import create_kvcache_manager, KVCacheManager
logger = get_logger("model_runner")
class ModelRunner:
@@ -120,9 +123,11 @@ class ModelRunner:
num_gpu_blocks = max_gpu_blocks
if config.enable_cpu_offload:
# Calculate CPU blocks based on cpu_memory_gb
cpu_bytes = int(config.cpu_memory_gb * 1024**3)
num_cpu_blocks = cpu_bytes // block_bytes
# Ping-Pong design: CPU is the primary storage, GPU is the working buffer.
# CPU blocks = all blocks needed for max_model_len (the full KV of one maximum-length sequence).
# GPU blocks = Ping-Pong working buffer (user-specified or automatic).
num_cpu_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
config.num_gpu_kvcache_blocks = num_gpu_blocks
config.num_cpu_kvcache_blocks = num_cpu_blocks
# For backward compatibility
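For concreteness, a small worked example of the CPU block sizing above, using illustrative numbers (not taken from the diff):
max_model_len, block_size = 4096, 16
num_cpu_blocks = (max_model_len + block_size - 1) // block_size   # ceil division
assert num_cpu_blocks == 256   # enough CPU blocks to hold one full-length sequence
# num_gpu_blocks stays at the user-specified value, or max_gpu_blocks when set to -1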
@@ -143,6 +148,27 @@ class ModelRunner:
dtype=hf_config.torch_dtype,
)
# Log KV cache allocation info
gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
total_memory_mb = gpu_memory_mb + cpu_memory_mb
if config.enable_cpu_offload:
logger.info(
f"KV Cache allocated (Ping-Pong mode): "
f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), "
f"Total={total_memory_mb:.1f}MB, "
f"block_size={self.block_size}, "
f"ping_size={config.num_gpu_kvcache_blocks // 2}"
)
else:
logger.info(
f"KV Cache allocated: "
f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
f"block_size={self.block_size}"
)
# Bind layer caches to attention modules and set layer_id
layer_id = 0
for module in self.model.modules():
@@ -328,7 +354,16 @@ class ModelRunner:
return self.model.compute_logits(graph_vars["outputs"][:bs])
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
# Check if chunked prefill is needed
# Check if Ping-Pong mode should be used (all blocks on CPU)
if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'):
use_pingpong = self._should_use_pingpong(seqs, is_prefill)
if use_pingpong:
if is_prefill:
return self.run_pingpong_prefill(seqs)
else:
return self.run_pingpong_decode(seqs)
# Check if chunked prefill is needed (legacy path)
if is_prefill and hasattr(self, 'kvcache_manager'):
needs_chunked = any(
hasattr(self.kvcache_manager, 'needs_chunked_prefill') and
@@ -338,7 +373,7 @@ class ModelRunner:
if needs_chunked:
return self.run_chunked_prefill(seqs)
# Check if chunked decode is needed
# Check if chunked decode is needed (legacy path)
if not is_prefill and hasattr(self, 'kvcache_manager'):
needs_chunked = any(
hasattr(self.kvcache_manager, 'needs_chunked_decode') and
@@ -355,6 +390,36 @@ class ModelRunner:
reset_context()
return token_ids
def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool:
"""
Check if Ping-Pong mode should be used.
Use Ping-Pong when:
- CPU offload is enabled
- There are blocks on CPU (either allocated there or offloaded)
- Sequence exceeds GPU capacity
"""
if not hasattr(self.kvcache_manager, 'offload_engine'):
return False
for seq in seqs:
if not seq.block_table:
continue # Skip warmup sequences
# Check if any blocks are on CPU
cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq)
if cpu_blocks:
# Has CPU blocks - use Ping-Pong
return True
# Check if sequence needs more blocks than GPU can hold
ping_size = self.kvcache_manager.offload_engine.ping_size
if seq.num_blocks > ping_size:
# Needs chunked processing
return True
return False
def run_chunked_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill in chunks when sequences exceed GPU capacity.
@@ -543,6 +608,210 @@ class ModelRunner:
return input_ids, positions
def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill with Ping-Pong dual buffer (CPU is primary storage).
Flow:
1. All blocks are allocated to CPU (primary storage)
2. Process tokens in chunks using Ping/Pong GPU buffers
3. After each chunk, offload from GPU to CPU
4. Alternate between Ping and Pong buffers
"""
import sys
assert len(seqs) == 1, "Ping-Pong prefill only supports single sequence"
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine
ping_size = offload_engine.ping_size
tokens_per_chunk = ping_size * self.block_size
total_tokens = len(seq)
print(f"[Ping-Pong Prefill] Starting: {total_tokens} tokens, "
f"ping_size={ping_size} blocks, chunk={tokens_per_chunk} tokens",
file=sys.stderr)
current_buffer = "ping"
chunk_num = 0
logits = None
processed_tokens = 0
# Get CPU block table for offload targets
cpu_block_ids, logical_ids = self.kvcache_manager.get_all_cpu_blocks(seq)
while processed_tokens < total_tokens:
chunk_num += 1
chunk_start = processed_tokens
chunk_end = min(processed_tokens + tokens_per_chunk, total_tokens)
chunk_tokens = chunk_end - chunk_start
# Calculate which CPU blocks this chunk covers
start_block_idx = chunk_start // self.block_size
end_block_idx = (chunk_end + self.block_size - 1) // self.block_size
num_blocks = end_block_idx - start_block_idx
print(f"[Ping-Pong Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
f"blocks {start_block_idx}-{end_block_idx-1}, buffer={current_buffer}",
file=sys.stderr)
# Get GPU slots for this chunk (Ping or Pong buffer)
if current_buffer == "ping":
gpu_slots = offload_engine.ping_slots[:num_blocks]
else:
gpu_slots = offload_engine.pong_slots[:num_blocks]
# Prepare inputs
input_ids, positions = self._prepare_pingpong_chunk(
seq, chunk_start, chunk_end, gpu_slots, start_block_idx
)
if input_ids.numel() == 0:
break
# Run model forward
logits = self.run_model(input_ids, positions, is_prefill=True)
reset_context()
# Mark blocks as prefilled
for i in range(start_block_idx, min(end_block_idx, len(seq.block_table))):
logical_id = seq.block_table[i]
self.kvcache_manager.prefilled_blocks.add(logical_id)
# Offload this chunk from GPU to CPU (async)
chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx]
offload_engine.offload_buffer_to_cpu(current_buffer, chunk_cpu_blocks)
# Switch buffer for next chunk
if current_buffer == "ping":
offload_engine.wait_ping_offload_done()
current_buffer = "pong"
else:
offload_engine.wait_pong_offload_done()
current_buffer = "ping"
processed_tokens = chunk_end
# Wait for all offloads to complete
offload_engine.wait_all_offload_done()
print(f"[Ping-Pong Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
# Sample from last logits
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
if logits is not None:
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
else:
token_ids = [0] if self.rank == 0 else None
return token_ids
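As a standalone illustration of the chunking arithmetic inside run_pingpong_prefill above, the following hypothetical helper mirrors how token ranges, block indices, and buffers alternate (sketch only, not part of the diff):
def pingpong_prefill_chunks(total_tokens: int, ping_size: int, block_size: int):
    """Yield (chunk_start, chunk_end, start_block_idx, end_block_idx, buffer),
    mirroring the loop in run_pingpong_prefill. Illustrative only."""
    tokens_per_chunk = ping_size * block_size
    processed, buffer = 0, "ping"
    while processed < total_tokens:
        chunk_end = min(processed + tokens_per_chunk, total_tokens)
        start_block = processed // block_size
        end_block = (chunk_end + block_size - 1) // block_size
        yield processed, chunk_end, start_block, end_block, buffer
        buffer = "pong" if buffer == "ping" else "ping"
        processed = chunk_end

# e.g. 100 tokens, ping_size=2, block_size=16 -> chunks of 32 tokens:
# (0, 32, 0, 2, 'ping'), (32, 64, 2, 4, 'pong'), (64, 96, 4, 6, 'ping'), (96, 100, 6, 7, 'pong')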
def _prepare_pingpong_chunk(
self,
seq: Sequence,
chunk_start: int,
chunk_end: int,
gpu_slots: list[int],
start_block_idx: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Prepare inputs for a Ping-Pong prefill chunk."""
# Input tokens for this chunk
input_ids = seq[chunk_start:chunk_end]
positions = list(range(chunk_start, chunk_end))
# Create slot mapping pointing to GPU slots
slot_mapping = []
num_tokens = chunk_end - chunk_start
token_idx = 0
for i, gpu_slot in enumerate(gpu_slots):
block_idx = start_block_idx + i
block_start = block_idx * self.block_size
block_end = min(block_start + self.block_size, len(seq))
# How many tokens in this block for this chunk
overlap_start = max(chunk_start, block_start)
overlap_end = min(chunk_end, block_end)
for pos in range(overlap_start, overlap_end):
pos_in_block = pos % self.block_size
slot = gpu_slot * self.block_size + pos_in_block
slot_mapping.append(slot)
# Convert to tensors
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# Set up context for chunked prefill
seqlen = num_tokens
cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=seqlen,
max_seqlen_k=seqlen,
slot_mapping=slot_mapping,
is_chunked_prefill=True,
offload_engine=self.kvcache_manager,
chunked_seq=seq,
)
return input_ids, positions
def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with Ping-Pong dual buffer.
All KV is on CPU. Uses Ping-Pong to load KV chunks and compute attention.
New token's KV is written to GPU then offloaded to CPU.
"""
assert len(seqs) == 1, "Ping-Pong decode only supports single sequence"
seq = seqs[0]
# Prepare inputs
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
# Get write slot for new KV (will use last slot of the buffer used for final chunk)
write_slot = self.kvcache_manager.get_write_slot_for_pingpong(seq)
# Calculate position in block for slot mapping
last_block_idx = seq.num_blocks - 1
pos_in_block = (len(seq) - 1) % self.block_size
slot = write_slot * self.block_size + pos_in_block
slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# Set up context for chunked decode
set_context(
is_prefill=False,
slot_mapping=slot_mapping,
context_lens=context_len,
is_chunked_prefill=True, # Use chunked attention path
offload_engine=self.kvcache_manager,
chunked_seq=seq,
)
# Run model forward pass
logits = self.run_model(input_ids, positions, is_prefill=False)
reset_context()
# Offload new KV from write_slot to CPU
last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
if last_cpu_block >= 0:
self.kvcache_manager.offload_engine.offload_slot_to_cpu(write_slot, last_cpu_block)
self.kvcache_manager.offload_engine.wait_all_offload_done()
# Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
return token_ids
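The slot index written into slot_mapping above is plain flat indexing into the GPU cache; a worked example with illustrative numbers:
block_size = 16
seq_len = 100                                   # len(seq) including the token being decoded
write_slot = 7                                  # GPU slot chosen by get_write_slot_for_pingpong
pos_in_block = (seq_len - 1) % block_size       # 99 % 16 = 3
slot = write_slot * block_size + pos_in_block   # 7 * 16 + 3 = 115
assert slot == 115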
@torch.inference_mode()
def capture_cudagraph(self):
config = self.config

View File

@@ -81,20 +81,24 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: int,
block_size: int,
policy: Optional[EvictionPolicy] = None,
cpu_primary: bool = True,
):
"""
Initialize hybrid manager.
Args:
num_gpu_slots: Number of GPU buffer slots (working set)
num_cpu_blocks: Number of CPU pool blocks (overflow)
num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
block_size: Tokens per block
policy: Eviction policy (default: LRU)
cpu_primary: If True, use CPU as primary storage with Ping-Pong GPU buffer.
If False, use GPU as primary with CPU as overflow (legacy mode).
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
self.num_cpu_blocks = num_cpu_blocks
self.total_blocks = num_gpu_slots + num_cpu_blocks
self.cpu_primary = cpu_primary # Ping-Pong mode flag
# Eviction policy
self.policy = policy or LRUPolicy()
@@ -321,12 +325,16 @@ class HybridKVCacheManager(KVCacheManager):
"""
Allocate logical blocks for prefill.
New blocks are allocated on GPU when possible. If GPU is full and all
GPU blocks belong to this sequence (can't evict), remaining blocks
are allocated to CPU for chunked prefill.
In cpu_primary mode (Ping-Pong): All blocks are allocated to CPU.
In legacy mode: Blocks are allocated to GPU when possible, overflow to CPU.
"""
assert not seq.block_table, "Sequence already has blocks"
# Ping-Pong mode: all blocks are allocated on CPU
if self.cpu_primary:
return self.allocate_cpu_only(seq)
# Legacy mode: GPU is primary, CPU is overflow
h = -1
cache_miss = False
@@ -451,13 +459,22 @@ class HybridKVCacheManager(KVCacheManager):
block.hash = -1
block.token_ids = []
# New decode blocks go to GPU
gpu_slot = self._allocate_gpu_slot()
block.location = BlockLocation.GPU
block.gpu_slot = gpu_slot
self.gpu_slot_to_logical[gpu_slot] = logical_id
self.policy.on_block_allocated(gpu_slot, self.current_step)
if self.cpu_primary:
# Ping-Pong mode: allocate the new block on CPU
if not self.free_cpu_blocks:
raise RuntimeError("No free CPU blocks for decode")
cpu_block_id = self.free_cpu_blocks.popleft()
block.location = BlockLocation.CPU
block.cpu_block_id = cpu_block_id
block.gpu_slot = -1
self.cpu_block_to_logical[cpu_block_id] = logical_id
else:
# Legacy mode: allocate the new block on GPU
gpu_slot = self._allocate_gpu_slot()
block.location = BlockLocation.GPU
block.gpu_slot = gpu_slot
self.gpu_slot_to_logical[gpu_slot] = logical_id
self.policy.on_block_allocated(gpu_slot, self.current_step)
block_table.append(logical_id)
@@ -993,6 +1010,158 @@ class HybridKVCacheManager(KVCacheManager):
break
return pos
# ========== Ping-Pong double-buffer support ==========
def allocate_cpu_only(self, seq: Sequence) -> None:
"""
Allocate CPU blocks for a sequence (used in Ping-Pong mode).
Unlike allocate(), all blocks here are placed on CPU;
the GPU is used only as a working buffer.
Args:
seq: Sequence to allocate blocks for
"""
assert not seq.block_table, "Sequence already has blocks"
for i in range(seq.num_blocks):
# Allocate a CPU block
if not self.free_cpu_blocks:
raise RuntimeError(
f"No free CPU blocks. Need {seq.num_blocks}, "
f"available: {len(self.free_cpu_blocks)}"
)
cpu_block_id = self.free_cpu_blocks.popleft()
# Allocate a logical block
logical_id = self.free_logical_ids.popleft()
block = self.logical_blocks[logical_id]
block.ref_count = 1
block.location = BlockLocation.CPU
block.cpu_block_id = cpu_block_id
block.gpu_slot = -1
self.cpu_block_to_logical[cpu_block_id] = logical_id
seq.block_table.append(logical_id)
def get_cpu_block_table(self, seq: Sequence) -> List[int]:
"""
Get the list of CPU block IDs for a sequence.
Args:
seq: Sequence
Returns:
List of CPU block IDs, in sequence order
"""
cpu_blocks = []
for logical_id in seq.block_table:
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
cpu_blocks.append(block.cpu_block_id)
else:
# If the block is on GPU it should still have a corresponding CPU block;
# in Ping-Pong mode all data ultimately lives on CPU.
raise RuntimeError(
f"Block {logical_id} not on CPU (location={block.location}). "
f"In Ping-Pong mode, all blocks should be on CPU."
)
return cpu_blocks
def get_all_cpu_blocks(self, seq: Sequence) -> Tuple[List[int], List[int]]:
"""
Get all CPU blocks of a sequence together with their logical IDs.
Args:
seq: Sequence
Returns:
(cpu_block_ids, logical_ids)
"""
cpu_block_ids = []
logical_ids = []
for logical_id in seq.block_table:
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
cpu_block_ids.append(block.cpu_block_id)
logical_ids.append(logical_id)
return cpu_block_ids, logical_ids
def allocate_next_cpu_block(self, seq: Sequence) -> int:
"""
Allocate the next CPU block for a sequence (for a new token during decode).
Args:
seq: Sequence
Returns:
Newly allocated CPU block ID
"""
if not self.free_cpu_blocks:
raise RuntimeError("No free CPU blocks")
cpu_block_id = self.free_cpu_blocks.popleft()
logical_id = self.free_logical_ids.popleft()
block = self.logical_blocks[logical_id]
block.ref_count = 1
block.location = BlockLocation.CPU
block.cpu_block_id = cpu_block_id
block.gpu_slot = -1
self.cpu_block_to_logical[cpu_block_id] = logical_id
seq.block_table.append(logical_id)
return cpu_block_id
def get_last_cpu_block(self, seq: Sequence) -> int:
"""
Get the CPU block ID of the sequence's last block.
Returns -1 if the last block is not on CPU.
Args:
seq: Sequence
Returns:
CPU block ID, or -1 if the block is not on CPU
"""
if not seq.block_table:
return -1
last_logical_id = seq.block_table[-1]
block = self.logical_blocks[last_logical_id]
if block.location == BlockLocation.CPU:
return block.cpu_block_id
return -1
def get_write_slot_for_pingpong(self, seq: Sequence) -> int:
"""
Get the GPU slot that the new token's KV is written to during Ping-Pong decode.
Strategy: use the number of chunks the sequence needs to determine whether the
final chunk used the Ping or the Pong buffer, then use that buffer's last slot.
Args:
seq: Sequence
Returns:
GPU slot ID
"""
cpu_blocks, _ = self.get_all_cpu_blocks(seq)
ping_size = self.offload_engine.ping_size
num_chunks = (len(cpu_blocks) + ping_size - 1) // ping_size if cpu_blocks else 0
# Which buffer the last chunk used
if num_chunks % 2 == 1 or num_chunks == 0:
# Odd number of chunks (or zero): the last chunk used ping
return self.offload_engine.ping_slots[-1]
else:
# Even number of chunks: the last chunk used pong
return self.offload_engine.pong_slots[-1]
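A quick, executable check of the parity rule above (illustrative helper, not part of the diff):
def final_buffer(num_cpu_blocks: int, ping_size: int) -> str:
    """Mirror of the parity rule in get_write_slot_for_pingpong (illustrative)."""
    num_chunks = (num_cpu_blocks + ping_size - 1) // ping_size if num_cpu_blocks else 0
    return "ping" if num_chunks % 2 == 1 or num_chunks == 0 else "pong"

assert final_buffer(5, 2) == "ping"   # 3 chunks: ping, pong, ping
assert final_buffer(3, 2) == "pong"   # 2 chunks: ping, pong
assert final_buffer(0, 2) == "ping"   # no chunks: default to ping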
def __repr__(self) -> str:
return (
f"HybridKVCacheManager(\n"

View File

@@ -64,6 +64,14 @@ class OffloadEngine:
self.kv_dim = num_kv_heads * head_dim
self.block_numel = block_size * self.kv_dim
# ========== Ping-Pong double-buffer configuration ==========
assert num_gpu_blocks >= 2, "Ping-Pong requires at least 2 GPU blocks"
self.ping_size = num_gpu_blocks // 2
self.pong_size = num_gpu_blocks - self.ping_size
self.ping_slots = list(range(self.ping_size)) # [0, 1, 2, ...]
self.pong_slots = list(range(self.ping_size, num_gpu_blocks)) # [ping_size, ping_size+1, ...]
self.num_gpu_slots = num_gpu_blocks # alias
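The split above simply halves the GPU slot range; for instance (illustrative numbers):
# Illustrative: num_gpu_blocks = 8
#   ping_size = 8 // 2 = 4, pong_size = 8 - 4 = 4
#   ping_slots = [0, 1, 2, 3], pong_slots = [4, 5, 6, 7]
# With an odd count such as 7: ping_size = 3, pong_size = 4.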
# ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
self.k_cache_gpu = torch.empty(
@@ -103,6 +111,17 @@ class OffloadEngine:
self.compute_stream = torch.cuda.current_stream()
self._stream_idx = 0
# ========== Dedicated Ping-Pong stream and events ==========
self.pingpong_stream = torch.cuda.Stream()  # Dedicated stream for Ping-Pong transfers
# Sync events - load completion
self.ping_ready = torch.cuda.Event()
self.pong_ready = torch.cuda.Event()
# Sync events - offload completion
self.ping_offload_done = torch.cuda.Event()
self.pong_offload_done = torch.cuda.Event()
# ========== Event tracking for async transfers ==========
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
@@ -516,7 +535,211 @@ class OffloadEngine:
f" kv_heads={self.num_kv_heads},\n"
f" head_dim={self.head_dim},\n"
f" dtype={self.dtype},\n"
f" ping_size={self.ping_size}, pong_size={self.pong_size},\n"
f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
f")"
)
)
# ========== Ping-Pong double-buffer methods ==========
def load_to_ping(self, cpu_block_ids: List[int]) -> None:
"""
Asynchronously load CPU blocks into the Ping buffer.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.ping_ready.record(self.pingpong_stream)
return
num_to_load = min(len(cpu_block_ids), self.ping_size)
logger.debug(f"Ping load: CPU{cpu_block_ids[:num_to_load]} -> GPU ping slots {self.ping_slots[:num_to_load]}")
with torch.cuda.stream(self.pingpong_stream):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.ping_slots[i]
# Copy all layers at once
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.ping_ready.record(self.pingpong_stream)
def load_to_pong(self, cpu_block_ids: List[int]) -> None:
"""
Asynchronously load CPU blocks into the Pong buffer.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.pong_ready.record(self.pingpong_stream)
return
num_to_load = min(len(cpu_block_ids), self.pong_size)
logger.debug(f"Pong load: CPU{cpu_block_ids[:num_to_load]} -> GPU pong slots {self.pong_slots[:num_to_load]}")
with torch.cuda.stream(self.pingpong_stream):
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.pong_slots[i]
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
self.v_cache_gpu[:, gpu_slot].copy_(
self.v_cache_cpu[:, cpu_id], non_blocking=True
)
self.pong_ready.record(self.pingpong_stream)
def wait_ping(self) -> None:
"""等待Ping buffer加载完成。"""
self.compute_stream.wait_event(self.ping_ready)
def wait_pong(self) -> None:
"""等待Pong buffer加载完成。"""
self.compute_stream.wait_event(self.pong_ready)
def offload_buffer_to_cpu(
self,
buffer: str,
cpu_block_ids: List[int],
) -> None:
"""
Asynchronously offload a buffer's KV to CPU.
Args:
buffer: "ping" or "pong"
cpu_block_ids: List of target CPU block IDs
"""
slots = self.ping_slots if buffer == "ping" else self.pong_slots
event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done
if not cpu_block_ids:
event.record(self.pingpong_stream)
return
num_to_offload = min(len(cpu_block_ids), len(slots))
logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
with torch.cuda.stream(self.pingpong_stream):
# Wait for compute to finish
self.pingpong_stream.wait_stream(self.compute_stream)
for i in range(num_to_offload):
gpu_slot = slots[i]
cpu_id = cpu_block_ids[i]
self.k_cache_cpu[:, cpu_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
)
event.record(self.pingpong_stream)
def offload_slot_to_cpu(
self,
gpu_slot: int,
cpu_block_id: int,
) -> None:
"""
Asynchronously offload a single GPU slot's KV to CPU.
Args:
gpu_slot: GPU slot ID
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]")
with torch.cuda.stream(self.pingpong_stream):
self.pingpong_stream.wait_stream(self.compute_stream)
self.k_cache_cpu[:, cpu_block_id].copy_(
self.k_cache_gpu[:, gpu_slot], non_blocking=True
)
self.v_cache_cpu[:, cpu_block_id].copy_(
self.v_cache_gpu[:, gpu_slot], non_blocking=True
)
def wait_ping_offload_done(self) -> None:
"""等待Ping buffer offload完成。"""
self.compute_stream.wait_event(self.ping_offload_done)
def wait_pong_offload_done(self) -> None:
"""等待Pong buffer offload完成。"""
self.compute_stream.wait_event(self.pong_offload_done)
def wait_all_offload_done(self) -> None:
"""等待所有offload完成。"""
self.pingpong_stream.synchronize()
def get_kv_for_ping_slots(
self,
layer_id: int,
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
Get the KV for the requested number of slots in the Ping buffer.
Args:
layer_id: Layer ID
num_slots: Number of slots needed
Returns:
(k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.ping_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots] # [num_slots, block_size, heads, dim]
v = self.v_cache_gpu[layer_id, slots]
# Reshape: [num_slots, block_size, heads, dim] -> [1, num_slots*block_size, heads, dim]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_pong_slots(
self,
layer_id: int,
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
Get the KV for the requested number of slots in the Pong buffer.
Args:
layer_id: Layer ID
num_slots: Number of slots needed
Returns:
(k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.pong_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots]
v = self.v_cache_gpu[layer_id, slots]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
def get_kv_for_slots(
self,
layer_id: int,
gpu_slots: List[int],
) -> Tuple[Tensor, Tensor]:
"""
Get the KV for the given GPU slots.
Args:
layer_id: Layer ID
gpu_slots: List of GPU slot IDs
Returns:
(k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim]
"""
if not gpu_slots:
return None, None
k = self.k_cache_gpu[layer_id, gpu_slots]
v = self.v_cache_gpu[layer_id, gpu_slots]
k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
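For reference, the gather-and-reshape performed by the three get_kv_* helpers, sketched standalone with illustrative sizes (assumes the [num_layers, num_blocks, block_size, kv_heads, head_dim] cache layout shown above):
import torch

num_layers, num_blocks, block_size, kv_heads, head_dim = 2, 8, 16, 4, 64
k_cache_gpu = torch.zeros(num_layers, num_blocks, block_size, kv_heads, head_dim)

layer_id, gpu_slots = 0, [0, 1, 2]
k = k_cache_gpu[layer_id, gpu_slots]          # [3, 16, 4, 64] - gather the requested blocks
k = k.reshape(1, -1, kv_heads, head_dim)      # [1, 48, 4, 64] - flatten into one contiguous sequence
assert k.shape == (1, len(gpu_slots) * block_size, kv_heads, head_dim)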

View File

@@ -97,51 +97,89 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute attention with chunked KV from CPU cache.
Compute attention with Ping-Pong dual buffer for chunked prefill.
For chunked prefill:
1. Load previous KV from CPU for this layer
2. Compute attention against previous KV (no causal mask)
1. Load previous KV from CPU using Ping-Pong (if any previous chunks)
2. Compute attention against previous KV chunks (no causal mask)
3. Compute attention against current chunk's KV (causal)
4. Merge results using online softmax
4. Merge all results using online softmax
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# q, k, v shape: [total_tokens, num_heads, head_dim]
total_tokens = q.shape[0]
# Reshape for flash attention: [batch, seq, heads, dim]
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
k_batched = k.unsqueeze(0)
v_batched = v.unsqueeze(0)
accumulated_o = None
accumulated_lse = None
o_acc = None
lse_acc = None
# Load previous KV from CPU for this layer
if context.offload_engine is not None and self.layer_id >= 0:
# Get the kvcache_manager from context
kvcache_manager = context.offload_engine
# Load previous KV from CPU using Ping-Pong
# Note: context.offload_engine is actually HybridKVCacheManager
kvcache_manager = context.offload_engine
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
# For each sequence in the chunk, load previous KV
# Currently assuming single sequence
if hasattr(context, 'chunked_seq') and context.chunked_seq is not None:
prev_k, prev_v = kvcache_manager.load_prev_kv_for_layer(
context.chunked_seq,
self.layer_id,
)
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
# Get prefilled CPU blocks (blocks already written in previous chunks)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
if prev_k is not None and prev_v is not None:
# Compute attention against previous KV (no causal mask)
if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
ping_size = offload_engine.ping_size
num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
current_buffer = "ping"
# Prefetch first chunk to Ping buffer
first_chunk_end = min(ping_size, len(cpu_block_table))
first_chunk_ids = cpu_block_table[:first_chunk_end]
offload_engine.load_to_ping(first_chunk_ids)
for chunk_idx in range(num_chunks):
start = chunk_idx * ping_size
end = min(start + ping_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
# Prefetch next chunk to OTHER buffer
if chunk_idx + 1 < num_chunks:
next_start = end
next_end = min(next_start + ping_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if current_buffer == "ping":
offload_engine.load_to_pong(next_chunk_ids)
else:
offload_engine.load_to_ping(next_chunk_ids)
# Wait for current buffer and get KV
if current_buffer == "ping":
offload_engine.wait_ping()
prev_k, prev_v = offload_engine.get_kv_for_ping_slots(
self.layer_id, num_blocks_in_chunk
)
else:
offload_engine.wait_pong()
prev_k, prev_v = offload_engine.get_kv_for_pong_slots(
self.layer_id, num_blocks_in_chunk
)
# Compute attention against this chunk (no causal mask)
prev_o, prev_lse = flash_attn_with_lse(
q_batched,
prev_k,
prev_v,
softmax_scale=self.scale,
causal=False, # No causal mask for previous context
causal=False,
)
accumulated_o = prev_o
accumulated_lse = prev_lse
# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Switch buffer
current_buffer = "pong" if current_buffer == "ping" else "ping"
# Compute attention against current chunk's KV (with causal mask)
current_o, current_lse = flash_attn_with_lse(
@@ -149,17 +187,14 @@ class Attention(nn.Module):
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True, # Causal mask for current chunk
causal=True,
)
# Merge with accumulated
if accumulated_o is None:
if o_acc is None:
final_o = current_o
else:
final_o, _ = merge_attention_outputs(
accumulated_o, accumulated_lse,
current_o, current_lse,
)
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
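The merge steps above go through merge_attention_outputs from nanovllm.kvcache.chunked_attention; its exact signature is not shown in this diff, but the standard log-sum-exp merge it is expected to implement looks roughly like this (sketch, assuming [batch, seq, heads, dim] outputs and [batch, heads, seq] LSE as returned by FlashAttention):
import torch

def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
    """Numerically stable merge of two partial attention results computed over
    disjoint key sets. Sketch only - shapes assumed: o* = [B, S, H, D], lse* = [B, H, S]."""
    lse = torch.logaddexp(lse1, lse2)                          # combined log-sum-exp
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # [B, S, H, 1] weight of part 1
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)   # [B, S, H, 1] weight of part 2
    return o1 * w1 + o2 * w2, lse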
@@ -172,12 +207,13 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute decode attention with KV spread across CPU and GPU.
Compute decode attention with Ping-Pong dual buffer.
Uses chunked attention similar to chunked prefill:
1. Process blocks on GPU first (if any)
2. Load CPU blocks in chunks to GPU slots (per-layer)
3. Compute attention for each chunk, merge with online softmax
All KV is stored on CPU. Uses Ping-Pong buffers on GPU:
1. Load first chunk to Ping buffer
2. While computing on current buffer, prefetch next chunk to other buffer
3. Alternate between Ping and Pong buffers
4. Merge attention outputs using online softmax (LSE)
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@@ -185,62 +221,73 @@ class Attention(nn.Module):
# Need: [batch, seqlen, heads, dim]
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
# Note: context.offload_engine is actually HybridKVCacheManager
kvcache_manager = context.offload_engine
seq = context.chunked_seq
# Get all CPU blocks for this sequence
cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
# Get the actual offload_engine for Ping-Pong operations
offload_engine = kvcache_manager.offload_engine
# Calculate chunk info
ping_size = offload_engine.ping_size
num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
o_acc = None
lse_acc = None
current_buffer = "ping"
# Step 1: Process blocks already on GPU (if any)
gpu_slots, _ = kvcache_manager.get_gpu_blocks_for_decode(seq)
if gpu_slots:
k_gpu, v_gpu = kvcache_manager.get_kv_for_gpu_slots(self.layer_id, gpu_slots)
o_gpu, lse_gpu = flash_attn_with_lse(
q_batched, k_gpu, v_gpu,
# Prefetch first chunk to Ping buffer (loads all layers at once)
first_chunk_end = min(ping_size, len(cpu_block_table))
first_chunk_ids = cpu_block_table[:first_chunk_end]
offload_engine.load_to_ping(first_chunk_ids)
for chunk_idx in range(num_chunks):
start = chunk_idx * ping_size
end = min(start + ping_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
# Prefetch next chunk to OTHER buffer (overlapped with current computation)
if chunk_idx + 1 < num_chunks:
next_start = end
next_end = min(next_start + ping_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if current_buffer == "ping":
offload_engine.load_to_pong(next_chunk_ids)
else:
offload_engine.load_to_ping(next_chunk_ids)
# Wait for current buffer to be ready and get KV
if current_buffer == "ping":
offload_engine.wait_ping()
k_chunk, v_chunk = offload_engine.get_kv_for_ping_slots(
self.layer_id, num_blocks_in_chunk
)
else:
offload_engine.wait_pong()
k_chunk, v_chunk = offload_engine.get_kv_for_pong_slots(
self.layer_id, num_blocks_in_chunk
)
# Compute attention for this chunk
o_chunk, lse_chunk = flash_attn_with_lse(
q_batched, k_chunk, v_chunk,
softmax_scale=self.scale,
causal=False,
)
o_acc, lse_acc = o_gpu, lse_gpu
# Step 2: Process CPU blocks in chunks
# Get chunk info from kvcache_manager
cpu_block_ids, cpu_logical_ids, num_chunks = kvcache_manager.get_decode_chunk_info(seq)
# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = o_chunk, lse_chunk
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
if num_chunks > 0:
# Use num_gpu_slots - 1 to avoid the reserved slot (used for write block)
chunk_size = kvcache_manager.num_gpu_slots - 1
for chunk_idx in range(num_chunks):
start = chunk_idx * chunk_size
end = min(start + chunk_size, len(cpu_block_ids))
chunk_cpu_ids = cpu_block_ids[start:end]
# Load this chunk to GPU slots 0, 1, 2, ... for THIS LAYER
# (slot num_gpu_slots-1 is reserved for write block)
gpu_slots_for_chunk = list(range(len(chunk_cpu_ids)))
kvcache_manager.offload_engine.load_cpu_blocks_to_gpu_slots(
self.layer_id,
chunk_cpu_ids,
gpu_slots_for_chunk,
)
# Get KV for this chunk
k_chunk, v_chunk = kvcache_manager.get_kv_for_gpu_slots(
self.layer_id, gpu_slots_for_chunk
)
# Compute attention for this chunk
o_chunk, lse_chunk = flash_attn_with_lse(
q_batched, k_chunk, v_chunk,
softmax_scale=self.scale,
causal=False,
)
# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = o_chunk, lse_chunk
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Switch buffer for next iteration
current_buffer = "pong" if current_buffer == "ping" else "ping"
if o_acc is None:
raise RuntimeError("Chunked decode attention failed: no KV available")