[refactor] Refactor current GPU and CPU block allocation strategy.
@@ -40,7 +40,6 @@ def main():
        max_model_len=128 * 1024,
        max_num_batched_tokens=128 * 1024,
        enable_cpu_offload=True,
        cpu_memory_gb=32.0,
    )

    # Warmup

@@ -19,7 +19,6 @@ class Config:

    # CPU Offload configuration
    enable_cpu_offload: bool = False
    cpu_memory_gb: float = 16.0  # CPU memory limit for KV cache
    offload_policy: str = "lru"  # "lru", "fifo", or full class path
    num_transfer_streams: int = 4  # Number of CUDA streams for async transfers
    num_gpu_blocks: int = -1  # User-specified GPU blocks count, -1 = auto (use max available)

@@ -10,8 +10,11 @@ from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import Sampler
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model
from nanovllm.utils.logger import get_logger
from nanovllm.kvcache import create_kvcache_manager, KVCacheManager

logger = get_logger("model_runner")


class ModelRunner:

@@ -120,9 +123,11 @@ class ModelRunner:
        num_gpu_blocks = max_gpu_blocks

        if config.enable_cpu_offload:
            # Calculate CPU blocks based on cpu_memory_gb
            cpu_bytes = int(config.cpu_memory_gb * 1024**3)
            num_cpu_blocks = cpu_bytes // block_bytes
            # Ping-Pong design: CPU is the primary storage, GPU is the working buffer
            # CPU blocks = all blocks needed for max_model_len (the full KV of one max-length sequence)
            # GPU blocks = Ping-Pong working buffer (user-specified or auto)
            num_cpu_blocks = (config.max_model_len + self.block_size - 1) // self.block_size

        config.num_gpu_kvcache_blocks = num_gpu_blocks
        config.num_cpu_kvcache_blocks = num_cpu_blocks
        # For backward compatibility
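
# --- Illustrative sketch (not part of the commit): how the Ping-Pong sizing above works out.
# Assuming block_size=256, max_model_len=16*1024 and a user-specified num_gpu_blocks=10
# (values are examples only), the CPU pool holds one full max-length sequence while the
# GPU working buffer is split into two halves (Ping and Pong).
block_size = 256
max_model_len = 16 * 1024
num_cpu_blocks = (max_model_len + block_size - 1) // block_size   # ceil -> 64 blocks on CPU
num_gpu_blocks = 10                                               # user-specified working buffer
ping_size = num_gpu_blocks // 2                                   # 5 blocks per Ping/Pong half
assert (num_cpu_blocks, ping_size) == (64, 5)
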
@@ -143,6 +148,27 @@ class ModelRunner:
            dtype=hf_config.torch_dtype,
        )

        # Log KV cache allocation info
        gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
        cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
        total_memory_mb = gpu_memory_mb + cpu_memory_mb

        if config.enable_cpu_offload:
            logger.info(
                f"KV Cache allocated (Ping-Pong mode): "
                f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
                f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), "
                f"Total={total_memory_mb:.1f}MB, "
                f"block_size={self.block_size}, "
                f"ping_size={config.num_gpu_kvcache_blocks // 2}"
            )
        else:
            logger.info(
                f"KV Cache allocated: "
                f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
                f"block_size={self.block_size}"
            )

        # Bind layer caches to attention modules and set layer_id
        layer_id = 0
        for module in self.model.modules():
@@ -328,7 +354,16 @@ class ModelRunner:
        return self.model.compute_logits(graph_vars["outputs"][:bs])

    def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
        # Check if chunked prefill is needed
        # Check if Ping-Pong mode should be used (all blocks on CPU)
        if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'):
            use_pingpong = self._should_use_pingpong(seqs, is_prefill)
            if use_pingpong:
                if is_prefill:
                    return self.run_pingpong_prefill(seqs)
                else:
                    return self.run_pingpong_decode(seqs)

        # Check if chunked prefill is needed (legacy path)
        if is_prefill and hasattr(self, 'kvcache_manager'):
            needs_chunked = any(
                hasattr(self.kvcache_manager, 'needs_chunked_prefill') and
@@ -338,7 +373,7 @@ class ModelRunner:
            if needs_chunked:
                return self.run_chunked_prefill(seqs)

        # Check if chunked decode is needed
        # Check if chunked decode is needed (legacy path)
        if not is_prefill and hasattr(self, 'kvcache_manager'):
            needs_chunked = any(
                hasattr(self.kvcache_manager, 'needs_chunked_decode') and
@@ -355,6 +390,36 @@ class ModelRunner:
        reset_context()
        return token_ids

    def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool:
        """
        Check if Ping-Pong mode should be used.

        Use Ping-Pong when:
        - CPU offload is enabled
        - There are blocks on CPU (either allocated there or offloaded)
        - Sequence exceeds GPU capacity
        """
        if not hasattr(self.kvcache_manager, 'offload_engine'):
            return False

        for seq in seqs:
            if not seq.block_table:
                continue  # Skip warmup sequences

            # Check if any blocks are on CPU
            cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq)
            if cpu_blocks:
                # Has CPU blocks - use Ping-Pong
                return True

            # Check if sequence needs more blocks than GPU can hold
            ping_size = self.kvcache_manager.offload_engine.ping_size
            if seq.num_blocks > ping_size:
                # Needs chunked processing
                return True

        return False

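# --- Illustrative sketch (not part of the commit): when the Ping-Pong path fires.
# With block_size=256 and num_gpu_blocks=10 (so ping_size=5), an 8192-token prompt
# needs 32 logical blocks, which exceeds the 5-block Ping buffer, so run() takes the
# Ping-Pong path. Numbers are examples only.
num_blocks_needed = (8192 + 256 - 1) // 256   # 32 blocks for the prompt
ping_size = 10 // 2                           # 5 slots per buffer half
use_pingpong = num_blocks_needed > ping_size  # True
assert use_pingpong
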
    def run_chunked_prefill(self, seqs: list[Sequence]) -> list[int]:
        """
        Run prefill in chunks when sequences exceed GPU capacity.
@@ -543,6 +608,210 @@ class ModelRunner:

        return input_ids, positions

    def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]:
        """
        Run prefill with Ping-Pong dual buffer (CPU is primary storage).

        Flow:
        1. All blocks are allocated to CPU (primary storage)
        2. Process tokens in chunks using Ping/Pong GPU buffers
        3. After each chunk, offload from GPU to CPU
        4. Alternate between Ping and Pong buffers
        """
        import sys

        assert len(seqs) == 1, "Ping-Pong prefill only supports single sequence"
        seq = seqs[0]

        offload_engine = self.kvcache_manager.offload_engine
        ping_size = offload_engine.ping_size
        tokens_per_chunk = ping_size * self.block_size

        total_tokens = len(seq)
        print(f"[Ping-Pong Prefill] Starting: {total_tokens} tokens, "
              f"ping_size={ping_size} blocks, chunk={tokens_per_chunk} tokens",
              file=sys.stderr)

        current_buffer = "ping"
        chunk_num = 0
        logits = None
        processed_tokens = 0

        # Get CPU block table for offload targets
        cpu_block_ids, logical_ids = self.kvcache_manager.get_all_cpu_blocks(seq)

        while processed_tokens < total_tokens:
            chunk_num += 1
            chunk_start = processed_tokens
            chunk_end = min(processed_tokens + tokens_per_chunk, total_tokens)
            chunk_tokens = chunk_end - chunk_start

            # Calculate which CPU blocks this chunk covers
            start_block_idx = chunk_start // self.block_size
            end_block_idx = (chunk_end + self.block_size - 1) // self.block_size
            num_blocks = end_block_idx - start_block_idx

            print(f"[Ping-Pong Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
                  f"blocks {start_block_idx}-{end_block_idx-1}, buffer={current_buffer}",
                  file=sys.stderr)

            # Get GPU slots for this chunk (Ping or Pong buffer)
            if current_buffer == "ping":
                gpu_slots = offload_engine.ping_slots[:num_blocks]
            else:
                gpu_slots = offload_engine.pong_slots[:num_blocks]

            # Prepare inputs
            input_ids, positions = self._prepare_pingpong_chunk(
                seq, chunk_start, chunk_end, gpu_slots, start_block_idx
            )

            if input_ids.numel() == 0:
                break

            # Run model forward
            logits = self.run_model(input_ids, positions, is_prefill=True)
            reset_context()

            # Mark blocks as prefilled
            for i in range(start_block_idx, min(end_block_idx, len(seq.block_table))):
                logical_id = seq.block_table[i]
                self.kvcache_manager.prefilled_blocks.add(logical_id)

            # Offload this chunk from GPU to CPU (async)
            chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx]
            offload_engine.offload_buffer_to_cpu(current_buffer, chunk_cpu_blocks)

            # Switch buffer for next chunk
            if current_buffer == "ping":
                offload_engine.wait_ping_offload_done()
                current_buffer = "pong"
            else:
                offload_engine.wait_pong_offload_done()
                current_buffer = "ping"

            processed_tokens = chunk_end

        # Wait for all offloads to complete
        offload_engine.wait_all_offload_done()

        print(f"[Ping-Pong Prefill] Complete: {chunk_num} chunks", file=sys.stderr)

        # Sample from last logits
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        if logits is not None:
            token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
        else:
            token_ids = [0] if self.rank == 0 else None

        return token_ids

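# --- Illustrative sketch (not part of the commit): how prefill chunks map to buffers.
# Assuming ping_size=5 blocks and block_size=256 (example values), each chunk covers up
# to 5*256=1280 tokens and chunks alternate Ping, Pong, Ping, ...
block_size, ping_size = 256, 5
total_tokens = 3000
tokens_per_chunk = ping_size * block_size            # 1280 tokens per chunk
chunks = []
start = 0
while start < total_tokens:
    end = min(start + tokens_per_chunk, total_tokens)
    buffer = "ping" if len(chunks) % 2 == 0 else "pong"
    chunks.append((start, end, buffer))
    start = end
assert chunks == [(0, 1280, "ping"), (1280, 2560, "pong"), (2560, 3000, "ping")]
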
    def _prepare_pingpong_chunk(
        self,
        seq: Sequence,
        chunk_start: int,
        chunk_end: int,
        gpu_slots: list[int],
        start_block_idx: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare inputs for a Ping-Pong prefill chunk."""
        # Input tokens for this chunk
        input_ids = seq[chunk_start:chunk_end]
        positions = list(range(chunk_start, chunk_end))

        # Create slot mapping pointing to GPU slots
        slot_mapping = []
        num_tokens = chunk_end - chunk_start

        token_idx = 0
        for i, gpu_slot in enumerate(gpu_slots):
            block_idx = start_block_idx + i
            block_start = block_idx * self.block_size
            block_end = min(block_start + self.block_size, len(seq))

            # How many tokens in this block for this chunk
            overlap_start = max(chunk_start, block_start)
            overlap_end = min(chunk_end, block_end)

            for pos in range(overlap_start, overlap_end):
                pos_in_block = pos % self.block_size
                slot = gpu_slot * self.block_size + pos_in_block
                slot_mapping.append(slot)

        # Convert to tensors
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

        # Set up context for chunked prefill
        seqlen = num_tokens
        cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

        set_context(
            is_prefill=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=seqlen,
            max_seqlen_k=seqlen,
            slot_mapping=slot_mapping,
            is_chunked_prefill=True,
            offload_engine=self.kvcache_manager,
            chunked_seq=seq,
        )

        return input_ids, positions

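# --- Illustrative sketch (not part of the commit): the slot-mapping arithmetic used above.
# A token at absolute position `pos` belongs to logical block pos // block_size; when that
# block is staged into a Ping/Pong GPU slot, the token is written at
# gpu_slot * block_size + (pos % block_size). Values below are examples only.
block_size = 256
gpu_slot = 0                              # first slot of the active buffer
pos = 770                                 # absolute token position (logical block 3)
pos_in_block = pos % block_size           # 2
slot = gpu_slot * block_size + pos_in_block
assert (pos // block_size, slot) == (3, 2)
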
    def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]:
        """
        Run decode with Ping-Pong dual buffer.

        All KV is on CPU. Uses Ping-Pong to load KV chunks and compute attention.
        New token's KV is written to GPU then offloaded to CPU.
        """
        assert len(seqs) == 1, "Ping-Pong decode only supports single sequence"
        seq = seqs[0]

        # Prepare inputs
        input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)

        # Get write slot for new KV (will use last slot of the buffer used for final chunk)
        write_slot = self.kvcache_manager.get_write_slot_for_pingpong(seq)

        # Calculate position in block for slot mapping
        last_block_idx = seq.num_blocks - 1
        pos_in_block = (len(seq) - 1) % self.block_size
        slot = write_slot * self.block_size + pos_in_block
        slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

        # Set up context for chunked decode
        set_context(
            is_prefill=False,
            slot_mapping=slot_mapping,
            context_lens=context_len,
            is_chunked_prefill=True,  # Use chunked attention path
            offload_engine=self.kvcache_manager,
            chunked_seq=seq,
        )

        # Run model forward pass
        logits = self.run_model(input_ids, positions, is_prefill=False)
        reset_context()

        # Offload new KV from write_slot to CPU
        last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
        if last_cpu_block >= 0:
            self.kvcache_manager.offload_engine.offload_slot_to_cpu(write_slot, last_cpu_block)
            self.kvcache_manager.offload_engine.wait_all_offload_done()

        # Sample
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None

        return token_ids

    @torch.inference_mode()
    def capture_cudagraph(self):
        config = self.config

@@ -81,20 +81,24 @@ class HybridKVCacheManager(KVCacheManager):
        num_cpu_blocks: int,
        block_size: int,
        policy: Optional[EvictionPolicy] = None,
        cpu_primary: bool = True,
    ):
        """
        Initialize hybrid manager.

        Args:
            num_gpu_slots: Number of GPU buffer slots (working set)
            num_cpu_blocks: Number of CPU pool blocks (overflow)
            num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
            block_size: Tokens per block
            policy: Eviction policy (default: LRU)
            cpu_primary: If True, use CPU as primary storage with Ping-Pong GPU buffer.
                If False, use GPU as primary with CPU as overflow (legacy mode).
        """
        self._block_size = block_size
        self.num_gpu_slots = num_gpu_slots
        self.num_cpu_blocks = num_cpu_blocks
        self.total_blocks = num_gpu_slots + num_cpu_blocks
        self.cpu_primary = cpu_primary  # Ping-Pong mode flag

        # Eviction policy
        self.policy = policy or LRUPolicy()
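
# --- Illustrative sketch (not part of the commit): constructing the manager in Ping-Pong
# mode. The import path and keyword usage below are assumptions based only on the
# signature shown in this diff; adjust them to the actual module layout.
# from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager  # assumed path
#
# manager = HybridKVCacheManager(
#     num_gpu_slots=10,      # small GPU working buffer (Ping + Pong halves)
#     num_cpu_blocks=64,     # primary storage, sized for max_model_len
#     block_size=256,
#     policy=None,           # defaults to LRU (only relevant in legacy mode)
#     cpu_primary=True,      # Ping-Pong mode: every block lives on CPU
# )
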
@@ -321,12 +325,16 @@ class HybridKVCacheManager(KVCacheManager):
        """
        Allocate logical blocks for prefill.

        New blocks are allocated on GPU when possible. If GPU is full and all
        GPU blocks belong to this sequence (can't evict), remaining blocks
        are allocated to CPU for chunked prefill.
        In cpu_primary mode (Ping-Pong): All blocks are allocated to CPU.
        In legacy mode: Blocks are allocated to GPU when possible, overflow to CPU.
        """
        assert not seq.block_table, "Sequence already has blocks"

        # Ping-Pong mode: all blocks are allocated on CPU
        if self.cpu_primary:
            return self.allocate_cpu_only(seq)

        # Legacy mode: GPU is primary, CPU is overflow
        h = -1
        cache_miss = False

@@ -451,13 +459,22 @@ class HybridKVCacheManager(KVCacheManager):
        block.hash = -1
        block.token_ids = []

        # New decode blocks go to GPU
        gpu_slot = self._allocate_gpu_slot()
        block.location = BlockLocation.GPU
        block.gpu_slot = gpu_slot

        self.gpu_slot_to_logical[gpu_slot] = logical_id
        self.policy.on_block_allocated(gpu_slot, self.current_step)
        if self.cpu_primary:
            # Ping-Pong mode: allocate the new block on CPU
            if not self.free_cpu_blocks:
                raise RuntimeError("No free CPU blocks for decode")
            cpu_block_id = self.free_cpu_blocks.popleft()
            block.location = BlockLocation.CPU
            block.cpu_block_id = cpu_block_id
            block.gpu_slot = -1
            self.cpu_block_to_logical[cpu_block_id] = logical_id
        else:
            # Legacy mode: allocate the new block on GPU
            gpu_slot = self._allocate_gpu_slot()
            block.location = BlockLocation.GPU
            block.gpu_slot = gpu_slot
            self.gpu_slot_to_logical[gpu_slot] = logical_id
            self.policy.on_block_allocated(gpu_slot, self.current_step)

        block_table.append(logical_id)

@@ -993,6 +1010,158 @@ class HybridKVCacheManager(KVCacheManager):
                break
        return pos

    # ========== Ping-Pong dual-buffer support ==========

    def allocate_cpu_only(self, seq: Sequence) -> None:
        """
        Allocate CPU blocks for a sequence (used in Ping-Pong mode).

        Unlike allocate(), all blocks here are allocated on CPU;
        the GPU is used only as a working buffer.

        Args:
            seq: Sequence to allocate blocks for
        """
        assert not seq.block_table, "Sequence already has blocks"

        for i in range(seq.num_blocks):
            # Allocate a CPU block
            if not self.free_cpu_blocks:
                raise RuntimeError(
                    f"No free CPU blocks. Need {seq.num_blocks}, "
                    f"available: {len(self.free_cpu_blocks)}"
                )

            cpu_block_id = self.free_cpu_blocks.popleft()

            # Allocate a logical block
            logical_id = self.free_logical_ids.popleft()
            block = self.logical_blocks[logical_id]
            block.ref_count = 1
            block.location = BlockLocation.CPU
            block.cpu_block_id = cpu_block_id
            block.gpu_slot = -1

            self.cpu_block_to_logical[cpu_block_id] = logical_id
            seq.block_table.append(logical_id)

    def get_cpu_block_table(self, seq: Sequence) -> List[int]:
        """
        Get the list of CPU block IDs for a sequence.

        Args:
            seq: Sequence

        Returns:
            List of CPU block IDs, in sequence order
        """
        cpu_blocks = []
        for logical_id in seq.block_table:
            block = self.logical_blocks[logical_id]
            if block.location == BlockLocation.CPU:
                cpu_blocks.append(block.cpu_block_id)
            else:
                # If a block is on GPU it should have a corresponding CPU block;
                # in Ping-Pong mode all data ultimately lives on CPU
                raise RuntimeError(
                    f"Block {logical_id} not on CPU (location={block.location}). "
                    f"In Ping-Pong mode, all blocks should be on CPU."
                )
        return cpu_blocks

    def get_all_cpu_blocks(self, seq: Sequence) -> Tuple[List[int], List[int]]:
        """
        Get all CPU blocks of a sequence together with their logical IDs.

        Args:
            seq: Sequence

        Returns:
            (cpu_block_ids, logical_ids)
        """
        cpu_block_ids = []
        logical_ids = []
        for logical_id in seq.block_table:
            block = self.logical_blocks[logical_id]
            if block.location == BlockLocation.CPU:
                cpu_block_ids.append(block.cpu_block_id)
                logical_ids.append(logical_id)
        return cpu_block_ids, logical_ids

    def allocate_next_cpu_block(self, seq: Sequence) -> int:
        """
        Allocate the next CPU block for a sequence (for a new token during decode).

        Args:
            seq: Sequence

        Returns:
            Newly allocated CPU block ID
        """
        if not self.free_cpu_blocks:
            raise RuntimeError("No free CPU blocks")

        cpu_block_id = self.free_cpu_blocks.popleft()
        logical_id = self.free_logical_ids.popleft()

        block = self.logical_blocks[logical_id]
        block.ref_count = 1
        block.location = BlockLocation.CPU
        block.cpu_block_id = cpu_block_id
        block.gpu_slot = -1

        self.cpu_block_to_logical[cpu_block_id] = logical_id
        seq.block_table.append(logical_id)

        return cpu_block_id

    def get_last_cpu_block(self, seq: Sequence) -> int:
        """
        Get the CPU block ID of the sequence's last block.

        Returns -1 if the last block is not on CPU.

        Args:
            seq: Sequence

        Returns:
            CPU block ID, or -1 if the block is not on CPU
        """
        if not seq.block_table:
            return -1

        last_logical_id = seq.block_table[-1]
        block = self.logical_blocks[last_logical_id]

        if block.location == BlockLocation.CPU:
            return block.cpu_block_id
        return -1

    def get_write_slot_for_pingpong(self, seq: Sequence) -> int:
        """
        Get the GPU slot that the new KV is written to during Ping-Pong decode.

        Strategy: use the number of chunks the sequence needs to decide whether the
        last chunk used the Ping or the Pong buffer, then use that buffer's last slot.

        Args:
            seq: Sequence

        Returns:
            GPU slot ID
        """
        cpu_blocks, _ = self.get_all_cpu_blocks(seq)
        ping_size = self.offload_engine.ping_size
        num_chunks = (len(cpu_blocks) + ping_size - 1) // ping_size if cpu_blocks else 0

        # Which buffer did the last chunk use?
        if num_chunks % 2 == 1 or num_chunks == 0:
            # Odd number of chunks (or zero): the last chunk used the Ping buffer
            return self.offload_engine.ping_slots[-1]
        else:
            # Even number of chunks: the last chunk used the Pong buffer
            return self.offload_engine.pong_slots[-1]

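# --- Illustrative sketch (not part of the commit): the write-slot parity rule above.
# With ping_size=2 and 5 CPU blocks (example values), the context is streamed in
# ceil(5/2)=3 chunks; an odd chunk count means the last chunk landed in the Ping
# buffer, so the new token's KV goes to the last Ping slot.
ping_size = 2
num_cpu_blocks = 5
num_chunks = (num_cpu_blocks + ping_size - 1) // ping_size   # 3 chunks
last_buffer = "ping" if (num_chunks % 2 == 1 or num_chunks == 0) else "pong"
assert (num_chunks, last_buffer) == (3, "ping")
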
    def __repr__(self) -> str:
        return (
            f"HybridKVCacheManager(\n"

@@ -64,6 +64,14 @@ class OffloadEngine:
        self.kv_dim = num_kv_heads * head_dim
        self.block_numel = block_size * self.kv_dim

        # ========== Ping-Pong dual-buffer configuration ==========
        assert num_gpu_blocks >= 2, "Ping-Pong requires at least 2 GPU blocks"
        self.ping_size = num_gpu_blocks // 2
        self.pong_size = num_gpu_blocks - self.ping_size
        self.ping_slots = list(range(self.ping_size))  # [0, 1, 2, ...]
        self.pong_slots = list(range(self.ping_size, num_gpu_blocks))  # [ping_size, ping_size+1, ...]
        self.num_gpu_slots = num_gpu_blocks  # alias

        # ========== Fixed-address GPU KV cache ==========
        # Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
        self.k_cache_gpu = torch.empty(
@@ -103,6 +111,17 @@ class OffloadEngine:
        self.compute_stream = torch.cuda.current_stream()
        self._stream_idx = 0

        # ========== Dedicated Ping-Pong stream and events ==========
        self.pingpong_stream = torch.cuda.Stream()  # dedicated to Ping-Pong transfers

        # Sync events: load complete
        self.ping_ready = torch.cuda.Event()
        self.pong_ready = torch.cuda.Event()

        # Sync events: offload complete
        self.ping_offload_done = torch.cuda.Event()
        self.pong_offload_done = torch.cuda.Event()

        # ========== Event tracking for async transfers ==========
        self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}

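# --- Illustrative sketch (not part of the commit): how the GPU slots are split.
# With num_gpu_blocks=10 (example value) the engine gets ping_size=5 and pong_size=5;
# the first half of the slot indices is the Ping buffer, the second half the Pong buffer.
num_gpu_blocks = 10
ping_size = num_gpu_blocks // 2
ping_slots = list(range(ping_size))                    # [0, 1, 2, 3, 4]
pong_slots = list(range(ping_size, num_gpu_blocks))    # [5, 6, 7, 8, 9]
assert ping_slots == [0, 1, 2, 3, 4] and pong_slots == [5, 6, 7, 8, 9]
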
@@ -516,7 +535,211 @@ class OffloadEngine:
            f"  kv_heads={self.num_kv_heads},\n"
            f"  head_dim={self.head_dim},\n"
            f"  dtype={self.dtype},\n"
            f"  ping_size={self.ping_size}, pong_size={self.pong_size},\n"
            f"  gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
            f"  cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
            f")"
        )
    )

    # ========== Ping-Pong dual-buffer methods ==========

    def load_to_ping(self, cpu_block_ids: List[int]) -> None:
        """
        Asynchronously load CPU blocks into the Ping buffer.

        Args:
            cpu_block_ids: List of CPU block IDs to load
        """
        if not cpu_block_ids:
            self.ping_ready.record(self.pingpong_stream)
            return

        num_to_load = min(len(cpu_block_ids), self.ping_size)
        logger.debug(f"Ping load: CPU{cpu_block_ids[:num_to_load]} -> GPU ping slots {self.ping_slots[:num_to_load]}")

        with torch.cuda.stream(self.pingpong_stream):
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = self.ping_slots[i]
                # Copy all layers at once
                self.k_cache_gpu[:, gpu_slot].copy_(
                    self.k_cache_cpu[:, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[:, gpu_slot].copy_(
                    self.v_cache_cpu[:, cpu_id], non_blocking=True
                )
            self.ping_ready.record(self.pingpong_stream)

    def load_to_pong(self, cpu_block_ids: List[int]) -> None:
        """
        Asynchronously load CPU blocks into the Pong buffer.

        Args:
            cpu_block_ids: List of CPU block IDs to load
        """
        if not cpu_block_ids:
            self.pong_ready.record(self.pingpong_stream)
            return

        num_to_load = min(len(cpu_block_ids), self.pong_size)
        logger.debug(f"Pong load: CPU{cpu_block_ids[:num_to_load]} -> GPU pong slots {self.pong_slots[:num_to_load]}")

        with torch.cuda.stream(self.pingpong_stream):
            for i in range(num_to_load):
                cpu_id = cpu_block_ids[i]
                gpu_slot = self.pong_slots[i]
                self.k_cache_gpu[:, gpu_slot].copy_(
                    self.k_cache_cpu[:, cpu_id], non_blocking=True
                )
                self.v_cache_gpu[:, gpu_slot].copy_(
                    self.v_cache_cpu[:, cpu_id], non_blocking=True
                )
            self.pong_ready.record(self.pingpong_stream)

    def wait_ping(self) -> None:
        """Wait for the Ping buffer load to complete."""
        self.compute_stream.wait_event(self.ping_ready)

    def wait_pong(self) -> None:
        """Wait for the Pong buffer load to complete."""
        self.compute_stream.wait_event(self.pong_ready)

    def offload_buffer_to_cpu(
        self,
        buffer: str,
        cpu_block_ids: List[int],
    ) -> None:
        """
        Asynchronously offload the KV held in a buffer to CPU.

        Args:
            buffer: "ping" or "pong"
            cpu_block_ids: List of destination CPU block IDs
        """
        slots = self.ping_slots if buffer == "ping" else self.pong_slots
        event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done

        if not cpu_block_ids:
            event.record(self.pingpong_stream)
            return

        num_to_offload = min(len(cpu_block_ids), len(slots))
        logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")

        with torch.cuda.stream(self.pingpong_stream):
            # Wait for compute to finish
            self.pingpong_stream.wait_stream(self.compute_stream)

            for i in range(num_to_offload):
                gpu_slot = slots[i]
                cpu_id = cpu_block_ids[i]
                self.k_cache_cpu[:, cpu_id].copy_(
                    self.k_cache_gpu[:, gpu_slot], non_blocking=True
                )
                self.v_cache_cpu[:, cpu_id].copy_(
                    self.v_cache_gpu[:, gpu_slot], non_blocking=True
                )
            event.record(self.pingpong_stream)

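# --- Illustrative sketch (not part of the commit): the stream-ordering pattern used above.
# The copy stream first waits on the compute stream (so the KV being copied is fully
# written), the device-to-host copies are enqueued, and an event is recorded so the
# compute stream can later wait for the offload without a full synchronize.
import torch

if torch.cuda.is_available():
    compute_stream = torch.cuda.current_stream()
    copy_stream = torch.cuda.Stream()
    done = torch.cuda.Event()
    src = torch.ones(4, device="cuda")
    dst = torch.empty(4, pin_memory=True)
    with torch.cuda.stream(copy_stream):
        copy_stream.wait_stream(compute_stream)   # order after pending compute work
        dst.copy_(src, non_blocking=True)         # async D2H copy into pinned memory
        done.record(copy_stream)
    compute_stream.wait_event(done)               # later: block compute until copy done
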
    def offload_slot_to_cpu(
        self,
        gpu_slot: int,
        cpu_block_id: int,
    ) -> None:
        """
        Asynchronously offload the KV of a single GPU slot to CPU.

        Args:
            gpu_slot: GPU slot ID
            cpu_block_id: Destination CPU block ID
        """
        logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]")

        with torch.cuda.stream(self.pingpong_stream):
            self.pingpong_stream.wait_stream(self.compute_stream)
            self.k_cache_cpu[:, cpu_block_id].copy_(
                self.k_cache_gpu[:, gpu_slot], non_blocking=True
            )
            self.v_cache_cpu[:, cpu_block_id].copy_(
                self.v_cache_gpu[:, gpu_slot], non_blocking=True
            )

    def wait_ping_offload_done(self) -> None:
        """Wait for the Ping buffer offload to complete."""
        self.compute_stream.wait_event(self.ping_offload_done)

    def wait_pong_offload_done(self) -> None:
        """Wait for the Pong buffer offload to complete."""
        self.compute_stream.wait_event(self.pong_offload_done)

    def wait_all_offload_done(self) -> None:
        """Wait for all offloads to complete."""
        self.pingpong_stream.synchronize()

    def get_kv_for_ping_slots(
        self,
        layer_id: int,
        num_slots: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get the KV for the first num_slots slots of the Ping buffer.

        Args:
            layer_id: Layer ID
            num_slots: Number of slots needed

        Returns:
            (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
        """
        slots = self.ping_slots[:num_slots]
        k = self.k_cache_gpu[layer_id, slots]  # [num_slots, block_size, heads, dim]
        v = self.v_cache_gpu[layer_id, slots]
        # Reshape: [num_slots, block_size, heads, dim] -> [1, num_slots*block_size, heads, dim]
        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
        return k, v

    def get_kv_for_pong_slots(
        self,
        layer_id: int,
        num_slots: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get the KV for the first num_slots slots of the Pong buffer.

        Args:
            layer_id: Layer ID
            num_slots: Number of slots needed

        Returns:
            (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
        """
        slots = self.pong_slots[:num_slots]
        k = self.k_cache_gpu[layer_id, slots]
        v = self.v_cache_gpu[layer_id, slots]
        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
        return k, v

    def get_kv_for_slots(
        self,
        layer_id: int,
        gpu_slots: List[int],
    ) -> Tuple[Tensor, Tensor]:
        """
        Get the KV for the given GPU slots.

        Args:
            layer_id: Layer ID
            gpu_slots: List of GPU slot IDs

        Returns:
            (k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim]
        """
        if not gpu_slots:
            return None, None
        k = self.k_cache_gpu[layer_id, gpu_slots]
        v = self.v_cache_gpu[layer_id, gpu_slots]
        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
        return k, v
@@ -97,51 +97,89 @@ class Attention(nn.Module):
        context,
    ) -> torch.Tensor:
        """
        Compute attention with chunked KV from CPU cache.
        Compute attention with Ping-Pong dual buffer for chunked prefill.

        For chunked prefill:
        1. Load previous KV from CPU for this layer
        2. Compute attention against previous KV (no causal mask)
        1. Load previous KV from CPU using Ping-Pong (if any previous chunks)
        2. Compute attention against previous KV chunks (no causal mask)
        3. Compute attention against current chunk's KV (causal)
        4. Merge results using online softmax
        4. Merge all results using online softmax
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        # q, k, v shape: [total_tokens, num_heads, head_dim]
        total_tokens = q.shape[0]

        # Reshape for flash attention: [batch, seq, heads, dim]
        q_batched = q.unsqueeze(0)  # [1, total_tokens, heads, dim]
        k_batched = k.unsqueeze(0)
        v_batched = v.unsqueeze(0)

        accumulated_o = None
        accumulated_lse = None
        o_acc = None
        lse_acc = None

        # Load previous KV from CPU for this layer
        if context.offload_engine is not None and self.layer_id >= 0:
            # Get the kvcache_manager from context
            kvcache_manager = context.offload_engine
        # Load previous KV from CPU using Ping-Pong
        # Note: context.offload_engine is actually HybridKVCacheManager
        kvcache_manager = context.offload_engine
        seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None

            # For each sequence in the chunk, load previous KV
            # Currently assuming single sequence
            if hasattr(context, 'chunked_seq') and context.chunked_seq is not None:
                prev_k, prev_v = kvcache_manager.load_prev_kv_for_layer(
                    context.chunked_seq,
                    self.layer_id,
                )
        if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
            # Get prefilled CPU blocks (blocks already written in previous chunks)
            cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

                if prev_k is not None and prev_v is not None:
                    # Compute attention against previous KV (no causal mask)
            if cpu_block_table:
                offload_engine = kvcache_manager.offload_engine
                ping_size = offload_engine.ping_size
                num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
                current_buffer = "ping"

                # Prefetch first chunk to Ping buffer
                first_chunk_end = min(ping_size, len(cpu_block_table))
                first_chunk_ids = cpu_block_table[:first_chunk_end]
                offload_engine.load_to_ping(first_chunk_ids)

                for chunk_idx in range(num_chunks):
                    start = chunk_idx * ping_size
                    end = min(start + ping_size, len(cpu_block_table))
                    num_blocks_in_chunk = end - start

                    # Prefetch next chunk to OTHER buffer
                    if chunk_idx + 1 < num_chunks:
                        next_start = end
                        next_end = min(next_start + ping_size, len(cpu_block_table))
                        next_chunk_ids = cpu_block_table[next_start:next_end]
                        if current_buffer == "ping":
                            offload_engine.load_to_pong(next_chunk_ids)
                        else:
                            offload_engine.load_to_ping(next_chunk_ids)

                    # Wait for current buffer and get KV
                    if current_buffer == "ping":
                        offload_engine.wait_ping()
                        prev_k, prev_v = offload_engine.get_kv_for_ping_slots(
                            self.layer_id, num_blocks_in_chunk
                        )
                    else:
                        offload_engine.wait_pong()
                        prev_k, prev_v = offload_engine.get_kv_for_pong_slots(
                            self.layer_id, num_blocks_in_chunk
                        )

                    # Compute attention against this chunk (no causal mask)
                    prev_o, prev_lse = flash_attn_with_lse(
                        q_batched,
                        prev_k,
                        prev_v,
                        softmax_scale=self.scale,
                        causal=False,  # No causal mask for previous context
                        causal=False,
                    )
                    accumulated_o = prev_o
                    accumulated_lse = prev_lse

                    # Merge with accumulated
                    if o_acc is None:
                        o_acc, lse_acc = prev_o, prev_lse
                    else:
                        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

                    # Switch buffer
                    current_buffer = "pong" if current_buffer == "ping" else "ping"

        # Compute attention against current chunk's KV (with causal mask)
        current_o, current_lse = flash_attn_with_lse(
@@ -149,17 +187,14 @@ class Attention(nn.Module):
            k_batched,
            v_batched,
            softmax_scale=self.scale,
            causal=True,  # Causal mask for current chunk
            causal=True,
        )

        # Merge with accumulated
        if accumulated_o is None:
        if o_acc is None:
            final_o = current_o
        else:
            final_o, _ = merge_attention_outputs(
                accumulated_o, accumulated_lse,
                current_o, current_lse,
            )
            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)

        # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
        return final_o.squeeze(0)
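
# --- Illustrative sketch (not part of the commit): merging two partial attention results
# with their log-sum-exp (LSE) values, i.e. the online-softmax trick used above. This is
# a minimal reference of what a helper like merge_attention_outputs computes; the actual
# implementation in nanovllm.kvcache.chunked_attention may differ in layout and numerics.
import torch

def merge_partial_attention(o1, lse1, o2, lse2):
    # o*: [batch, seq, heads, dim]; lse*: [batch, heads, seq] (layout assumed here)
    lse = torch.logaddexp(lse1, lse2)                          # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # -> [batch, seq, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse                              # weighted sum of partials
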
@@ -172,12 +207,13 @@ class Attention(nn.Module):
        context,
    ) -> torch.Tensor:
        """
        Compute decode attention with KV spread across CPU and GPU.
        Compute decode attention with Ping-Pong dual buffer.

        Uses chunked attention similar to chunked prefill:
        1. Process blocks on GPU first (if any)
        2. Load CPU blocks in chunks to GPU slots (per-layer)
        3. Compute attention for each chunk, merge with online softmax
        All KV is stored on CPU. Uses Ping-Pong buffers on GPU:
        1. Load first chunk to Ping buffer
        2. While computing on current buffer, prefetch next chunk to other buffer
        3. Alternate between Ping and Pong buffers
        4. Merge attention outputs using online softmax (LSE)
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

@@ -185,62 +221,73 @@ class Attention(nn.Module):
        # Need: [batch, seqlen, heads, dim]
        q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]

        # Note: context.offload_engine is actually HybridKVCacheManager
        kvcache_manager = context.offload_engine
        seq = context.chunked_seq

        # Get all CPU blocks for this sequence
        cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
        if not cpu_block_table:
            raise RuntimeError("Chunked decode attention failed: no CPU blocks available")

        # Get the actual offload_engine for Ping-Pong operations
        offload_engine = kvcache_manager.offload_engine

        # Calculate chunk info
        ping_size = offload_engine.ping_size
        num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size

        o_acc = None
        lse_acc = None
        current_buffer = "ping"

        # Step 1: Process blocks already on GPU (if any)
        gpu_slots, _ = kvcache_manager.get_gpu_blocks_for_decode(seq)
        if gpu_slots:
            k_gpu, v_gpu = kvcache_manager.get_kv_for_gpu_slots(self.layer_id, gpu_slots)
            o_gpu, lse_gpu = flash_attn_with_lse(
                q_batched, k_gpu, v_gpu,
        # Prefetch first chunk to Ping buffer (loads all layers at once)
        first_chunk_end = min(ping_size, len(cpu_block_table))
        first_chunk_ids = cpu_block_table[:first_chunk_end]
        offload_engine.load_to_ping(first_chunk_ids)

        for chunk_idx in range(num_chunks):
            start = chunk_idx * ping_size
            end = min(start + ping_size, len(cpu_block_table))
            num_blocks_in_chunk = end - start

            # Prefetch next chunk to OTHER buffer (overlapped with current computation)
            if chunk_idx + 1 < num_chunks:
                next_start = end
                next_end = min(next_start + ping_size, len(cpu_block_table))
                next_chunk_ids = cpu_block_table[next_start:next_end]
                if current_buffer == "ping":
                    offload_engine.load_to_pong(next_chunk_ids)
                else:
                    offload_engine.load_to_ping(next_chunk_ids)

            # Wait for current buffer to be ready and get KV
            if current_buffer == "ping":
                offload_engine.wait_ping()
                k_chunk, v_chunk = offload_engine.get_kv_for_ping_slots(
                    self.layer_id, num_blocks_in_chunk
                )
            else:
                offload_engine.wait_pong()
                k_chunk, v_chunk = offload_engine.get_kv_for_pong_slots(
                    self.layer_id, num_blocks_in_chunk
                )

            # Compute attention for this chunk
            o_chunk, lse_chunk = flash_attn_with_lse(
                q_batched, k_chunk, v_chunk,
                softmax_scale=self.scale,
                causal=False,
            )
            o_acc, lse_acc = o_gpu, lse_gpu

        # Step 2: Process CPU blocks in chunks
        # Get chunk info from kvcache_manager
        cpu_block_ids, cpu_logical_ids, num_chunks = kvcache_manager.get_decode_chunk_info(seq)
            # Merge with accumulated
            if o_acc is None:
                o_acc, lse_acc = o_chunk, lse_chunk
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)

        if num_chunks > 0:
            # Use num_gpu_slots - 1 to avoid the reserved slot (used for write block)
            chunk_size = kvcache_manager.num_gpu_slots - 1

            for chunk_idx in range(num_chunks):
                start = chunk_idx * chunk_size
                end = min(start + chunk_size, len(cpu_block_ids))
                chunk_cpu_ids = cpu_block_ids[start:end]

                # Load this chunk to GPU slots 0, 1, 2, ... for THIS LAYER
                # (slot num_gpu_slots-1 is reserved for write block)
                gpu_slots_for_chunk = list(range(len(chunk_cpu_ids)))
                kvcache_manager.offload_engine.load_cpu_blocks_to_gpu_slots(
                    self.layer_id,
                    chunk_cpu_ids,
                    gpu_slots_for_chunk,
                )

                # Get KV for this chunk
                k_chunk, v_chunk = kvcache_manager.get_kv_for_gpu_slots(
                    self.layer_id, gpu_slots_for_chunk
                )

                # Compute attention for this chunk
                o_chunk, lse_chunk = flash_attn_with_lse(
                    q_batched, k_chunk, v_chunk,
                    softmax_scale=self.scale,
                    causal=False,
                )

                # Merge with accumulated
                if o_acc is None:
                    o_acc, lse_acc = o_chunk, lse_chunk
                else:
                    o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
            # Switch buffer for next iteration
            current_buffer = "pong" if current_buffer == "ping" else "ping"

        if o_acc is None:
            raise RuntimeError("Chunked decode attention failed: no KV available")

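# --- Illustrative sketch (not part of the commit): the double-buffer prefetch pattern the
# decode path above follows, shown without CUDA so the control flow is easy to see.
# Chunk i is consumed from one buffer while chunk i+1 is loaded into the other; the
# buffers swap every iteration.
def double_buffer_iter(chunks, load, wait, buffers=("ping", "pong")):
    if not chunks:
        return
    current = 0
    load(buffers[current], chunks[0])                   # prefetch the first chunk
    for i, chunk in enumerate(chunks):
        if i + 1 < len(chunks):
            load(buffers[1 - current], chunks[i + 1])   # overlap the next load
        wait(buffers[current])                          # ensure the current chunk is resident
        yield buffers[current], chunk                   # compute on the current buffer
        current = 1 - current                           # swap Ping <-> Pong

# Example: list(double_buffer_iter([[0, 1], [2, 3], [4]], load=lambda b, c: None,
#                                  wait=lambda b: None))
# -> [("ping", [0, 1]), ("pong", [2, 3]), ("ping", [4])]
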
@@ -14,63 +14,66 @@ os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
from nanovllm import LLM, SamplingParams


def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=16):
def create_long_context_prompt(target_tokens: int) -> str:
    """
    Create a meaningful long context prompt with a question at the end.
    The answer depends on information scattered throughout the context.
    """
    # Key facts to embed in the context
    facts = [
        "The capital of France is Paris.",
        "The Eiffel Tower was built in 1889.",
        "Python was created by Guido van Rossum.",
        "The speed of light is approximately 299,792 kilometers per second.",
        "Mount Everest is 8,848 meters tall.",
    ]

    # Padding text to reach target length
    padding_paragraph = """
    This is additional context information that helps extend the length of the prompt.
    Machine learning has revolutionized many fields including computer vision, natural language processing, and robotics.
    Deep neural networks can learn complex patterns from large amounts of data.
    The transformer architecture has become the foundation of modern language models.
    Attention mechanisms allow models to focus on relevant parts of the input.
    """

    # Build the prompt
    prompt_parts = []

    # Add instruction
    prompt_parts.append("Please read the following information carefully and answer the question at the end.\n\n")

    # Add facts at different positions
    current_tokens = 50  # approximate tokens so far
    tokens_per_padding = 80  # approximate tokens per padding paragraph
    fact_interval = target_tokens // (len(facts) + 1)

    fact_idx = 0
    while current_tokens < target_tokens - 100:
        # Add padding
        prompt_parts.append(padding_paragraph)
        current_tokens += tokens_per_padding

        # Add a fact at intervals
        if fact_idx < len(facts) and current_tokens > fact_interval * (fact_idx + 1):
            prompt_parts.append(f"\n[Important Fact #{fact_idx + 1}]: {facts[fact_idx]}\n")
            current_tokens += 20
            fact_idx += 1

    # Add the question at the end
    prompt_parts.append("\n\nQuestion: Based on the information above, what is the capital of France and when was the Eiffel Tower built? Please answer briefly.\n\nAnswer:")

    return "".join(prompt_parts)


def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64):
    """Test chunked prefill with limited GPU blocks."""
    path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")

    total_blocks = (input_len + 255) // 256
    print(f"=" * 60)
    print(f"Chunked Prefill Test")
    print(f"Chunked Prefill Test (Ping-Pong)")
    print(f"=" * 60)
    print(f"  input_len: {input_len} tokens")
    print(f"  total_blocks: {total_blocks}")
    print(f"  num_gpu_blocks: {num_gpu_blocks}")
    print(f"  blocks_on_cpu: {max(0, total_blocks - num_gpu_blocks)}")
    print()

    llm = LLM(
        path,
        enforce_eager=False,
        max_model_len=16 * 1024,  # 16K is enough for 8K test
        max_num_batched_tokens=16 * 1024,
        enable_cpu_offload=True,
        cpu_memory_gb=4.0,
        num_gpu_blocks=num_gpu_blocks,
    )

    print(f"LLM initialized:")
    print(f"  num_gpu_kvcache_blocks: {llm.model_runner.config.num_gpu_kvcache_blocks}")
    print(f"  num_cpu_kvcache_blocks: {llm.model_runner.config.num_cpu_kvcache_blocks}")
    print()

    # Create prompt with approximate token count
    prompt = "Hello " * (input_len // 2)

    print(f"Running generation...")
    outputs = llm.generate(
        [prompt],
        SamplingParams(temperature=0.6, max_tokens=output_len),
        use_tqdm=True,
    )

    print()
    print(f"Output tokens: {len(outputs[0]['token_ids'])}")
    print(f"Output text (first 100 chars): {outputs[0]['text'][:100]}")
    print()
    return outputs


def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=64):
    """Test chunked decode with limited GPU blocks."""
    path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")

    total_blocks = (input_len + 255) // 256
    print(f"=" * 60)
    print(f"Chunked Decode Test")
    print(f"=" * 60)
    print(f"  input_len: {input_len} tokens")
    print(f"  output_len: {output_len} tokens")
    print(f"  total_blocks: {total_blocks}")
    print(f"  target_input_len: ~{input_len} tokens")
    print(f"  num_gpu_blocks: {num_gpu_blocks}")
    print()

@@ -80,27 +83,62 @@ def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=64):
        max_model_len=16 * 1024,
        max_num_batched_tokens=16 * 1024,
        enable_cpu_offload=True,
        cpu_memory_gb=4.0,
        num_gpu_blocks=num_gpu_blocks,
    )

    print(f"LLM initialized:")
    print(f"  num_gpu_kvcache_blocks: {llm.model_runner.config.num_gpu_kvcache_blocks}")
    print(f"  num_cpu_kvcache_blocks: {llm.model_runner.config.num_cpu_kvcache_blocks}")
    print()

    prompt = "Hello " * (input_len // 2)
    # Create meaningful prompt
    prompt = create_long_context_prompt(input_len)

    print(f"Running generation...")
    outputs = llm.generate(
        [prompt],
        SamplingParams(temperature=0.6, max_tokens=output_len),
        use_tqdm=True,
        SamplingParams(temperature=0.1, max_tokens=output_len),  # low temperature for more deterministic output
        use_tqdm=False,
    )

    print()
    print(f"Output tokens: {len(outputs[0]['token_ids'])}")
    print(f"Output text (first 100 chars): {outputs[0]['text'][:100]}")
    print(f"Output text:\n{outputs[0]['text']}")
    print()
    return outputs


def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128):
    """Test chunked decode with limited GPU blocks."""
    path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")

    print(f"=" * 60)
    print(f"Chunked Decode Test (Ping-Pong)")
    print(f"=" * 60)
    print(f"  target_input_len: ~{input_len} tokens")
    print(f"  output_len: {output_len} tokens")
    print(f"  num_gpu_blocks: {num_gpu_blocks}")
    print()

    llm = LLM(
        path,
        enforce_eager=False,
        max_model_len=16 * 1024,
        max_num_batched_tokens=16 * 1024,
        enable_cpu_offload=True,
        num_gpu_blocks=num_gpu_blocks,
    )
    print()

    # Create meaningful prompt
    prompt = create_long_context_prompt(input_len)

    print(f"Running generation...")
    outputs = llm.generate(
        [prompt],
        SamplingParams(temperature=0.1, max_tokens=output_len),
        use_tqdm=False,
    )

    print()
    print(f"Output tokens: {len(outputs[0]['token_ids'])}")
    print(f"Output text:\n{outputs[0]['text']}")
    print()
    return outputs

@@ -108,7 +146,7 @@ def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=64):
if __name__ == "__main__":
    # Parse arguments
    num_gpu_blocks = int(sys.argv[1]) if len(sys.argv) > 1 else 10
    input_len = int(sys.argv[2]) if len(sys.argv) > 2 else 8192
    output_len = int(sys.argv[3]) if len(sys.argv) > 3 else 32
    input_len = int(sys.argv[2]) if len(sys.argv) > 2 else 2048
    output_len = int(sys.argv[3]) if len(sys.argv) > 3 else 64

    test_chunked_prefill(num_gpu_blocks, input_len, output_len)