[refactor] Translate comments into English, avoid Chinese due to Claude.

Zijie Tian
2025-12-11 00:30:24 +08:00
parent e85c2b4776
commit babfa17354
9 changed files with 297 additions and 187 deletions

View File

@@ -22,7 +22,7 @@ class Config:
offload_policy: str = "lru" # "lru", "fifo", or full class path
num_transfer_streams: int = 4 # Number of CUDA streams for async transfers
num_gpu_blocks: int = -1 # User-specified GPU blocks count, -1 = auto (use max available)
num_prefetch_blocks: int = 2 # Number of prefetch blocks for three-region GPU buffer design
# Computed fields for offload (set in __post_init__ or by ModelRunner)
num_gpu_kvcache_blocks: int = -1

View File

@@ -123,9 +123,9 @@ class ModelRunner:
num_gpu_blocks = max_gpu_blocks
if config.enable_cpu_offload:
# Three-region design: CPU is primary storage, GPU is working buffer
# CPU blocks = all blocks needed to support max_model_len (stores complete KV for one max sequence)
# GPU blocks = three-region working buffer (user-specified or auto)
num_cpu_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
config.num_gpu_kvcache_blocks = num_gpu_blocks
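
As a quick illustration of the block-budget arithmetic above, here is a standalone sketch; block_size is an assumed value, and max_model_len matches the updated test script at the end of this commit:

    block_size = 256                 # assumed; the real value comes from the model config
    max_model_len = 128 * 1024       # matches the updated test script below
    num_gpu_blocks = 10              # small GPU working buffer (user-specified or auto)

    # CPU is primary storage: enough blocks to hold one full-length sequence's KV.
    num_cpu_blocks = (max_model_len + block_size - 1) // block_size   # -> 512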
@@ -412,12 +412,12 @@ class ModelRunner:
def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool:
"""
Check if three-region mode should be used.
Use three-region when:
- CPU offload is enabled
- There are blocks on CPU (either allocated there or offloaded)
- Sequence exceeds GPU Compute region capacity
"""
if not hasattr(self.kvcache_manager, 'offload_engine'):
return False
@@ -429,10 +429,10 @@ class ModelRunner:
# Check if any blocks are on CPU
cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq)
if cpu_blocks:
# Has CPU blocks - use three-region
return True
# Check if sequence needs more blocks than GPU Compute region can hold
compute_size = self.kvcache_manager.offload_engine.num_compute_blocks
if seq.num_blocks > compute_size:
# Needs chunked processing
@@ -630,17 +630,17 @@ class ModelRunner:
def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]:
"""
Run prefill with three-region GPU buffer (CPU is primary storage).
Flow:
1. All blocks are allocated to CPU (primary storage)
2. Process tokens in chunks using the Compute region GPU buffer
3. After each chunk, offload from the Compute region to CPU
4. The Prefetch region is used to load previous KV (if any)
"""
import sys
assert len(seqs) == 1, "Three-region prefill only supports single sequence"
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine
@@ -648,7 +648,7 @@ class ModelRunner:
tokens_per_chunk = compute_size * self.block_size
total_tokens = len(seq)
print(f"[Three-region Prefill] Starting: {total_tokens} tokens, "
f"compute_size={compute_size} blocks, chunk={tokens_per_chunk} tokens",
file=sys.stderr)
@@ -670,12 +670,12 @@ class ModelRunner:
end_block_idx = (chunk_end + self.block_size - 1) // self.block_size
num_blocks = end_block_idx - start_block_idx
print(f"[Three-region Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
f"blocks {start_block_idx}-{end_block_idx-1}, "
f"compute_slots={offload_engine.compute_slots[:num_blocks]}",
file=sys.stderr)
# Get GPU slots for this chunk (using Compute region)
gpu_slots = offload_engine.compute_slots[:num_blocks]
# Prepare inputs
@@ -695,7 +695,7 @@ class ModelRunner:
logical_id = seq.block_table[i]
self.kvcache_manager.prefilled_blocks.add(logical_id)
# Offload this chunk from Compute region to CPU (async)
chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx]
offload_engine.offload_compute_to_cpu(chunk_cpu_blocks)
@@ -707,7 +707,7 @@ class ModelRunner:
# Wait for all offloads to complete
offload_engine.wait_all_offload_done()
print(f"[Three-region Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
# Sample from last logits
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
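
A condensed sketch of the chunked-prefill loop described in the hunks above; run_chunk and cpu_block_ids are hypothetical stand-ins for pieces not shown in this diff, while the offload_engine calls are the ones introduced here:

    def chunked_prefill(seq_len, block_size, compute_slots, cpu_block_ids, run_chunk, offload_engine):
        # Each chunk fills the Compute region, then is copied out to its CPU blocks asynchronously.
        tokens_per_chunk = len(compute_slots) * block_size
        for chunk_start in range(0, seq_len, tokens_per_chunk):
            chunk_end = min(chunk_start + tokens_per_chunk, seq_len)
            start_block = chunk_start // block_size
            end_block = (chunk_end + block_size - 1) // block_size
            run_chunk(chunk_start, chunk_end, compute_slots[:end_block - start_block])
            offload_engine.offload_compute_to_cpu(cpu_block_ids[start_block:end_block])
        offload_engine.wait_all_offload_done()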
@@ -776,14 +776,15 @@ class ModelRunner:
def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]:
"""
Run decode with three-region GPU buffer.
All KV is on CPU. Uses the Decode region to write new KV and the Compute/Prefetch regions to load KV chunks.
The new token's KV is written to the Decode region (slot 0), then offloaded to CPU only when the block is full.
Key: the Decode region is never overwritten by Compute/Prefetch; it is dedicated to writing new KV.
Optimization: batch offloads - only offload when the block is full, and attend to all accumulated tokens.
"""
assert len(seqs) == 1, "Three-region decode only supports single sequence"
seq = seqs[0]
offload_engine = self.kvcache_manager.offload_engine
@@ -792,13 +793,16 @@ class ModelRunner:
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
# Use Decode region (slot 0) to write new KV
decode_slot = offload_engine.decode_slot # = 0
pos_in_block = (len(seq) - 1) % self.block_size
slot = decode_slot * self.block_size + pos_in_block
slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
# Get decode start position for accumulated token tracking
decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)
# Set up context for chunked decode
set_context(
is_prefill=False,
@@ -808,17 +812,22 @@ class ModelRunner:
offload_engine=self.kvcache_manager,
chunked_seq=seq,
decode_pos_in_block=pos_in_block,
decode_start_pos_in_block=decode_start_pos,
)
# Run model forward pass
logits = self.run_model(input_ids, positions, is_prefill=False)
reset_context()
# Only offload when block is full (pos_in_block == block_size - 1)
# This avoids unnecessary offloading on every decode step
if pos_in_block == self.block_size - 1:
last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
if last_cpu_block >= 0:
offload_engine.offload_decode_slot(last_cpu_block)
offload_engine.wait_all_offload_done()
# Reset decode start position for next block
self.kvcache_manager.reset_decode_start_pos(seq)
# Sample
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
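
A minimal sketch of the batched-offload decision above; run_forward and sample are hypothetical stand-ins, while the manager/engine calls are the ones introduced in this commit:

    def decode_step(seq, block_size, kvcache_manager, offload_engine, run_forward, sample):
        pos_in_block = (len(seq) - 1) % block_size
        logits = run_forward(seq, write_slot=offload_engine.decode_slot, pos=pos_in_block)
        # Batch offloads: copy the Decode slot to CPU only once the block is full.
        if pos_in_block == block_size - 1:
            last_cpu_block = kvcache_manager.get_last_cpu_block(seq)
            if last_cpu_block >= 0:
                offload_engine.offload_decode_slot(last_cpu_block)
                offload_engine.wait_all_offload_done()
                kvcache_manager.reset_decode_start_pos(seq)
        return sample(logits)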

View File

@@ -95,16 +95,16 @@ class HybridKVCacheManager(KVCacheManager):
num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
block_size: Tokens per block
policy: Eviction policy (default: LRU)
cpu_primary: If True, use CPU as primary storage with three-region GPU buffer.
If False, use GPU as primary with CPU as overflow (legacy mode).
num_prefetch_blocks: Number of prefetch blocks for three-region GPU buffer design
"""
self._block_size = block_size
self.num_gpu_slots = num_gpu_slots
self.num_cpu_blocks = num_cpu_blocks
self.total_blocks = num_gpu_slots + num_cpu_blocks
self.cpu_primary = cpu_primary # Three-region mode flag
self.num_prefetch_blocks = num_prefetch_blocks # Three-region design parameter
# Eviction policy
self.policy = policy or LRUPolicy()
@@ -138,6 +138,10 @@ class HybridKVCacheManager(KVCacheManager):
# Track blocks that have been prefilled (KV written) for chunked prefill
self.prefilled_blocks: Set[int] = set() # logical_ids
# Track decode starting position within block (for batched offload optimization)
# Key: sequence id, Value: starting position where decode began in current block
self._decode_start_pos: Dict[int, int] = {}
@property
def block_size(self) -> int:
return self._block_size
@@ -337,11 +341,11 @@ class HybridKVCacheManager(KVCacheManager):
""" """
assert not seq.block_table, "Sequence already has blocks" assert not seq.block_table, "Sequence already has blocks"
# Ping-Pong模式所有blocks都分配到CPU # Three-region mode: all blocks are allocated to CPU
if self.cpu_primary: if self.cpu_primary:
return self.allocate_cpu_only(seq) return self.allocate_cpu_only(seq)
# Legacy模式GPU为主CPU为overflow # Legacy mode: GPU as primary, CPU as overflow
h = -1 h = -1
cache_miss = False cache_miss = False
@@ -467,7 +471,7 @@ class HybridKVCacheManager(KVCacheManager):
block.token_ids = []
if self.cpu_primary:
# Three-region mode: new block allocated to CPU
if not self.free_cpu_blocks:
raise RuntimeError("No free CPU blocks for decode")
cpu_block_id = self.free_cpu_blocks.popleft()
@@ -476,7 +480,7 @@ class HybridKVCacheManager(KVCacheManager):
block.gpu_slot = -1
self.cpu_block_to_logical[cpu_block_id] = logical_id
else:
# Legacy mode: new block allocated to GPU
gpu_slot = self._allocate_gpu_slot()
block.location = BlockLocation.GPU
block.gpu_slot = gpu_slot
@@ -1021,22 +1025,22 @@ class HybridKVCacheManager(KVCacheManager):
break
return pos
# ========== Three-region double buffering support ==========
def allocate_cpu_only(self, seq: Sequence) -> None:
"""
Allocate CPU blocks for a sequence (for three-region mode).
Unlike allocate(), here all blocks are allocated to CPU;
the GPU is used only as a working buffer.
Args:
seq: Sequence to allocate
"""
assert not seq.block_table, "Sequence already has blocks"
for i in range(seq.num_blocks):
# Allocate CPU block
if not self.free_cpu_blocks:
raise RuntimeError(
f"No free CPU blocks. Need {seq.num_blocks}, "
@@ -1045,7 +1049,7 @@ class HybridKVCacheManager(KVCacheManager):
cpu_block_id = self.free_cpu_blocks.popleft()
# Allocate logical block
logical_id = self.free_logical_ids.popleft()
block = self.logical_blocks[logical_id]
block.ref_count = 1
@@ -1058,13 +1062,13 @@ class HybridKVCacheManager(KVCacheManager):
def get_cpu_block_table(self, seq: Sequence) -> List[int]:
"""
Get CPU block ID list for sequence.
Args:
seq: Sequence
Returns:
List of CPU block IDs in sequence order
"""
cpu_blocks = []
for logical_id in seq.block_table:
@@ -1072,20 +1076,20 @@ class HybridKVCacheManager(KVCacheManager):
if block.location == BlockLocation.CPU:
cpu_blocks.append(block.cpu_block_id)
else:
# If block is on GPU, it should have a corresponding CPU block
# In three-region mode, all data ultimately resides on CPU
raise RuntimeError(
f"Block {logical_id} not on CPU (location={block.location}). "
f"In three-region mode, all blocks should be on CPU."
)
return cpu_blocks
def get_all_cpu_blocks(self, seq: Sequence) -> Tuple[List[int], List[int]]:
"""
Get all CPU blocks and their logical IDs for sequence.
Args:
seq: Sequence
Returns:
(cpu_block_ids, logical_ids)
@@ -1101,13 +1105,13 @@ class HybridKVCacheManager(KVCacheManager):
def allocate_next_cpu_block(self, seq: Sequence) -> int:
"""
Allocate the next CPU block for a sequence (for a new token during decode).
Args:
seq: Sequence
Returns:
Newly allocated CPU block ID
"""
if not self.free_cpu_blocks:
raise RuntimeError("No free CPU blocks")
@@ -1128,15 +1132,15 @@ class HybridKVCacheManager(KVCacheManager):
def get_last_cpu_block(self, seq: Sequence) -> int:
"""
Get the CPU block ID of the last block in a sequence.
Returns -1 if the last block is not on CPU.
Args:
seq: Sequence
Returns:
CPU block ID, or -1 if not on CPU
"""
if not seq.block_table:
return -1
@@ -1150,19 +1154,65 @@ class HybridKVCacheManager(KVCacheManager):
def get_write_slot_for_pingpong(self, seq: Sequence) -> int:
"""
Get the GPU slot for writing new KV during three-region decode.
In the three-region design, always use the Decode region (slot 0) to write new KV.
This avoids conflicts with Compute/Prefetch region loading operations.
Args:
seq: Sequence
Returns:
GPU slot ID (always decode_slot = 0)
"""
return self.offload_engine.decode_slot
def get_decode_start_pos(self, seq: Sequence) -> int:
"""
Get the starting position within block where decode tokens began.
This is used for batched offload optimization - we need to attend to all
accumulated tokens in decode slot, not just the current one.
Args:
seq: Sequence
Returns:
Starting position within block (0 to block_size-1)
"""
seq_id = id(seq)
if seq_id not in self._decode_start_pos:
# First decode step - compute starting position
# After prefill, the last block has some tokens filled
# Decode starts at the next position
prefill_len = len(seq) - 1 # Current len includes the new decode token
self._decode_start_pos[seq_id] = prefill_len % self._block_size
return self._decode_start_pos[seq_id]
def reset_decode_start_pos(self, seq: Sequence) -> None:
"""
Reset decode start position for sequence.
Called when block is full and offloaded - next decode starts at position 0.
Args:
seq: Sequence
"""
seq_id = id(seq)
self._decode_start_pos[seq_id] = 0
def clear_decode_tracking(self, seq: Sequence) -> None:
"""
Clear decode position tracking for sequence.
Called when sequence is deallocated.
Args:
seq: Sequence
"""
seq_id = id(seq)
self._decode_start_pos.pop(seq_id, None)
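# Worked example of the start-position arithmetic above (illustrative values only, not from this diff):
#   with block_size = 16 and a 100-token prefill, the first decode step sees len(seq) == 101,
#   so decode_start_pos = (101 - 1) % 16 == 4; decode tokens then accumulate in the Decode
#   slot at positions 4..15, and once position 15 is written the block is offloaded and the
#   tracked start position resets to 0 for the next block.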
def __repr__(self) -> str:
return (
f"HybridKVCacheManager(\n"

View File

@@ -65,44 +65,44 @@ class OffloadEngine:
self.kv_dim = num_kv_heads * head_dim
self.block_numel = block_size * self.kv_dim
# ========== Three-region GPU Buffer configuration ==========
# Constraint checks
assert num_gpu_blocks >= 3, \
f"Need at least 3 GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
assert num_prefetch_blocks >= 1, \
f"Need at least 1 prefetch block, got {num_prefetch_blocks}"
assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \
f"Insufficient GPU blocks: need 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}"
# Three-region configuration
# Decode region: [0] - fixed single block for writing new KV
self.decode_slot = 0
# Compute region: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
compute_start = 1
compute_end = num_gpu_blocks - num_prefetch_blocks
self.compute_slots = list(range(compute_start, compute_end))
self.num_compute_blocks = len(self.compute_slots)
# Prefetch region: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
prefetch_start = compute_end
self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks))
self.num_prefetch_blocks = num_prefetch_blocks
self.num_gpu_slots = num_gpu_blocks # alias
# Keep old ping/pong attributes for compatibility (will be removed later)
self.ping_size = self.num_compute_blocks
self.pong_size = self.num_prefetch_blocks
self.ping_slots = self.compute_slots.copy()
self.pong_slots = self.prefetch_slots.copy()
logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, "
f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")
# ========== Fixed-address GPU KV cache ==========
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
# Use zeros initialization to avoid uninitialized memory issues
self.k_cache_gpu = torch.zeros(
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cuda"
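
A standalone sketch of the slot partitioning above, using the num_gpu_blocks=10 / num_prefetch_blocks=2 defaults from the test script in this commit:

    num_gpu_blocks = 10
    num_prefetch_blocks = 2

    decode_slot = 0                                             # 1 block reserved for new KV writes
    compute_end = num_gpu_blocks - num_prefetch_blocks
    compute_slots = list(range(1, compute_end))                 # [1, 2, 3, 4, 5, 6, 7]
    prefetch_slots = list(range(compute_end, num_gpu_blocks))   # [8, 9]
    assert 1 + len(compute_slots) + len(prefetch_slots) == num_gpu_blocks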
@@ -140,15 +140,15 @@ class OffloadEngine:
self.compute_stream = torch.cuda.current_stream()
self._stream_idx = 0
# ========== Three-region dedicated stream and events ==========
self.transfer_stream_main = torch.cuda.Stream() # Main transfer stream
# Sync events - three-region loading completion
self.compute_ready = torch.cuda.Event()
self.prefetch_ready = torch.cuda.Event()
self.decode_offload_done = torch.cuda.Event()
# Keep old ping/pong events for compatibility (will be removed later)
self.pingpong_stream = self.transfer_stream_main
self.ping_ready = self.compute_ready
self.pong_ready = self.prefetch_ready
@@ -568,20 +568,20 @@ class OffloadEngine:
f" kv_heads={self.num_kv_heads},\n" f" kv_heads={self.num_kv_heads},\n"
f" head_dim={self.head_dim},\n" f" head_dim={self.head_dim},\n"
f" dtype={self.dtype},\n" f" dtype={self.dtype},\n"
f" 三区域: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n" f" three_regions: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n" f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n" f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
f")" f")"
) )
# ========== Ping-Pong 双缓冲方法 ========== # ========== Ping-Pong double buffering methods ==========
def load_to_ping(self, cpu_block_ids: List[int]) -> None: def load_to_ping(self, cpu_block_ids: List[int]) -> None:
""" """
异步加载CPU blocksPing buffer Async load CPU blocks to Ping buffer.
Args: Args:
cpu_block_ids: 要加载的CPU block IDs列表 cpu_block_ids: List of CPU block IDs to load
""" """
if not cpu_block_ids: if not cpu_block_ids:
self.ping_ready.record(self.pingpong_stream) self.ping_ready.record(self.pingpong_stream)
@@ -594,7 +594,7 @@ class OffloadEngine:
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.ping_slots[i]
# Copy all layers together
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
@@ -605,10 +605,10 @@ class OffloadEngine:
def load_to_pong(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Pong buffer.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.pong_ready.record(self.pingpong_stream)
@@ -630,11 +630,11 @@ class OffloadEngine:
self.pong_ready.record(self.pingpong_stream)
def wait_ping(self) -> None:
"""Wait for Ping buffer loading to complete."""
self.compute_stream.wait_event(self.ping_ready)
def wait_pong(self) -> None:
"""Wait for Pong buffer loading to complete."""
self.compute_stream.wait_event(self.pong_ready)
def offload_buffer_to_cpu(
@@ -643,11 +643,11 @@ class OffloadEngine:
cpu_block_ids: List[int],
) -> None:
"""
Async offload KV from buffer to CPU.
Args:
buffer: "ping" or "pong"
cpu_block_ids: List of target CPU block IDs
"""
slots = self.ping_slots if buffer == "ping" else self.pong_slots
event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done
@@ -660,7 +660,7 @@ class OffloadEngine:
logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}") logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
with torch.cuda.stream(self.pingpong_stream): with torch.cuda.stream(self.pingpong_stream):
# 等待计算完成 # Wait for compute to complete
self.pingpong_stream.wait_stream(self.compute_stream) self.pingpong_stream.wait_stream(self.compute_stream)
for i in range(num_to_offload): for i in range(num_to_offload):
@@ -680,11 +680,11 @@ class OffloadEngine:
cpu_block_id: int,
) -> None:
"""
Async offload a single GPU slot's KV to CPU.
Args:
gpu_slot: GPU slot ID
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]")
@@ -698,15 +698,15 @@ class OffloadEngine:
)
def wait_ping_offload_done(self) -> None:
"""Wait for Ping buffer offload to complete."""
self.compute_stream.wait_event(self.ping_offload_done)
def wait_pong_offload_done(self) -> None:
"""Wait for Pong buffer offload to complete."""
self.compute_stream.wait_event(self.pong_offload_done)
def wait_all_offload_done(self) -> None:
"""Wait for all offload operations to complete."""
self.pingpong_stream.synchronize()
def get_kv_for_ping_slots(
@@ -715,14 +715,14 @@ class OffloadEngine:
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for the specified number of slots in the Ping buffer.
Args:
layer_id: Layer ID
num_slots: Number of slots needed
Returns:
(k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.ping_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots] # [num_slots, block_size, heads, dim]
@@ -738,14 +738,14 @@ class OffloadEngine:
num_slots: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for the specified number of slots in the Pong buffer.
Args:
layer_id: Layer ID
num_slots: Number of slots needed
Returns:
(k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim]
"""
slots = self.pong_slots[:num_slots]
k = self.k_cache_gpu[layer_id, slots]
@@ -760,14 +760,14 @@ class OffloadEngine:
gpu_slots: List[int],
) -> Tuple[Tensor, Tensor]:
"""
Get KV for the specified GPU slots.
Args:
layer_id: Layer ID
gpu_slots: List of GPU slot IDs
Returns:
(k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim]
"""
if not gpu_slots:
return None, None
@@ -777,14 +777,14 @@ class OffloadEngine:
v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
return k, v
# ========== Three-region GPU Buffer methods ==========
def load_to_compute(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Compute region.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.compute_ready.record(self.transfer_stream_main)
@@ -797,7 +797,7 @@ class OffloadEngine:
for i in range(num_to_load):
cpu_id = cpu_block_ids[i]
gpu_slot = self.compute_slots[i]
# Copy all layers together
self.k_cache_gpu[:, gpu_slot].copy_(
self.k_cache_cpu[:, cpu_id], non_blocking=True
)
@@ -808,10 +808,10 @@ class OffloadEngine:
def load_to_prefetch(self, cpu_block_ids: List[int]) -> None:
"""
Async load CPU blocks to Prefetch region.
Args:
cpu_block_ids: List of CPU block IDs to load
"""
if not cpu_block_ids:
self.prefetch_ready.record(self.transfer_stream_main)
@@ -833,25 +833,25 @@ class OffloadEngine:
self.prefetch_ready.record(self.transfer_stream_main)
def wait_compute(self) -> None:
"""Wait for Compute region loading to complete."""
self.compute_stream.wait_event(self.compute_ready)
def wait_prefetch(self) -> None:
"""Wait for Prefetch region loading to complete."""
self.compute_stream.wait_event(self.prefetch_ready)
def swap_compute_prefetch(self) -> None:
"""Swap the roles of the Compute region and the Prefetch region."""
self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
# Also update old ping/pong slots for compatibility
self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots
def offload_decode_slot(self, cpu_block_id: int) -> None:
"""
Offload KV from Decode region to CPU.
Args:
cpu_block_id: Target CPU block ID
"""
logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]")
@@ -866,7 +866,7 @@ class OffloadEngine:
self.decode_offload_done.record(self.transfer_stream_main)
def wait_decode_offload(self) -> None:
"""Wait for Decode region offload to complete."""
self.compute_stream.wait_event(self.decode_offload_done)
def get_kv_for_compute(
@@ -875,14 +875,14 @@ class OffloadEngine:
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for the specified number of blocks in the Compute region.
Args:
layer_id: Layer ID
num_blocks: Number of blocks needed
Returns:
(k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.compute_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots] # [num_blocks, block_size, heads, dim]
@@ -898,14 +898,14 @@ class OffloadEngine:
num_blocks: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV for the specified number of blocks in the Prefetch region.
Args:
layer_id: Layer ID
num_blocks: Number of blocks needed
Returns:
(k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
"""
slots = self.prefetch_slots[:num_blocks]
k = self.k_cache_gpu[layer_id, slots]
@@ -920,14 +920,14 @@ class OffloadEngine:
pos_in_block: int,
) -> Tuple[Tensor, Tensor]:
"""
Get KV at the specified position in the Decode region (for the new token during decode).
Args:
layer_id: Layer ID
pos_in_block: Token position within block (0 to block_size-1)
Returns:
(k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] # [1, heads, dim]
v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
@@ -935,12 +935,36 @@ class OffloadEngine:
v = v.unsqueeze(0)
return k, v
def get_kv_for_decode_slot_accumulated(
self,
layer_id: int,
num_tokens: int,
) -> Tuple[Tensor, Tensor]:
""" """
将Compute区的KV offload到CPU。 Get accumulated KV in Decode region (all tokens from position 0 to num_tokens-1).
Used when batching decode offloads - attend to all accumulated tokens,
not just the current one.
Args:
layer_id: Layer ID
num_tokens: Number of accumulated tokens (1 to block_size)
Returns:
(k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim]
"""
k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens] # [num_tokens, heads, dim]
v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens]
k = k.unsqueeze(0) # [1, num_tokens, heads, dim]
v = v.unsqueeze(0)
return k, v
def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
"""
Offload KV from Compute region to CPU.
Args:
cpu_block_ids: Target CPU block IDs list
""" """
if not cpu_block_ids: if not cpu_block_ids:
return return
@@ -949,7 +973,7 @@ class OffloadEngine:
logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}") logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
with torch.cuda.stream(self.transfer_stream_main): with torch.cuda.stream(self.transfer_stream_main):
# 等待计算完成 # Wait for compute to complete
self.transfer_stream_main.wait_stream(self.compute_stream) self.transfer_stream_main.wait_stream(self.compute_stream)
for i in range(num_to_offload): for i in range(num_to_offload):

View File

@@ -100,16 +100,16 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute attention with three-region GPU buffer for chunked prefill.
For chunked prefill:
1. Load previous KV from CPU using the Compute/Prefetch regions (if any previous chunks)
2. Compute attention against previous KV chunks (no causal mask)
3. Compute attention against current chunk's KV (causal)
4. Merge all results using online softmax
The three-region design guarantees: the current chunk's KV is in the Compute region, previous KV is loaded
from CPU to the Prefetch region, so write and load regions never overlap.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@@ -122,7 +122,7 @@ class Attention(nn.Module):
o_acc = None
lse_acc = None
# Load previous KV from CPU using Compute/Prefetch region
# Note: context.offload_engine is actually HybridKVCacheManager
kvcache_manager = context.offload_engine
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
@@ -133,12 +133,12 @@ class Attention(nn.Module):
if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
# Use Prefetch region to load previous KV (won't conflict with current Compute region)
prefetch_size = offload_engine.num_prefetch_blocks
num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
use_compute = True # Alternate between Compute region and Prefetch region
# First load previous KV to Prefetch region
# Only layer 0 triggers the load (loads ALL layers at once)
first_chunk_end = min(prefetch_size, len(cpu_block_table))
first_chunk_ids = cpu_block_table[:first_chunk_end]
@@ -157,14 +157,14 @@ class Attention(nn.Module):
next_end = min(next_start + prefetch_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if use_compute:
# Currently reading from the Prefetch region; the next chunk would go to the Compute region (if space allows)
# Note: the Compute region already holds the current chunk's KV and cannot be overwritten
# So we use a simple sync strategy here: wait for the current chunk to complete before loading the next
pass # Simplified version: no double buffering, only use the Prefetch region
else:
offload_engine.load_to_prefetch(next_chunk_ids)
# Wait for Prefetch region and get KV
offload_engine.wait_prefetch()
prev_k, prev_v = offload_engine.get_kv_for_prefetch(
self.layer_id, num_blocks_in_chunk
@@ -185,7 +185,7 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
# Load next chunk to Prefetch region (if it exists)
if chunk_idx + 1 < num_chunks and self.layer_id == 0:
next_start = end
next_end = min(next_start + prefetch_size, len(cpu_block_table))
@@ -218,16 +218,16 @@ class Attention(nn.Module):
context,
) -> torch.Tensor:
"""
Compute decode attention with three-region GPU buffer.
All KV is stored on CPU. Uses the Compute region buffer on GPU:
1. Load chunk to Compute region
2. Compute attention
3. Repeat for all chunks
4. Finally, attend to the Decode region (slot 0), which contains the new token's KV
5. Merge all attention outputs using online softmax (LSE)
Key: the new token's KV is in the Decode region (slot 0) and won't be overwritten by Compute region loading.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
@@ -246,10 +246,10 @@ class Attention(nn.Module):
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
# Get the actual offload_engine for three-region operations
offload_engine = kvcache_manager.offload_engine
# Calculate chunk info using Compute region
compute_size = offload_engine.num_compute_blocks
num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size
@@ -262,12 +262,12 @@ class Attention(nn.Module):
num_blocks_in_chunk = end - start
chunk_ids = cpu_block_table[start:end]
# Load this chunk to Compute region
# Only layer 0 triggers the load (loads ALL layers at once)
if self.layer_id == 0:
offload_engine.load_to_compute(chunk_ids)
# Wait for Compute region to be ready and get KV
offload_engine.wait_compute()
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
self.layer_id, num_blocks_in_chunk
@@ -286,10 +286,20 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Now attend to Decode region (contains accumulated decode tokens)
# When batching offloads, decode slot accumulates multiple tokens
# from decode_start_pos_in_block to decode_pos_in_block (inclusive)
pos_in_block = context.decode_pos_in_block
start_pos = context.decode_start_pos_in_block
num_accumulated = pos_in_block - start_pos + 1
if num_accumulated > 0:
# Get accumulated KV in decode slot [start_pos : pos_in_block+1]
decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_k = decode_k.unsqueeze(0) # [1, num_tokens, heads, dim]
decode_v = decode_v.unsqueeze(0)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, decode_k, decode_v,
softmax_scale=self.scale,
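
The merge step used throughout this file combines per-chunk attention outputs via their log-sum-exp values; a minimal sketch of that combination rule under assumed shapes (not the actual nanovllm.kvcache.chunked_attention implementation):

    import torch

    def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
        # Assumed shapes: o is [batch, q_len, heads, head_dim], lse is [batch, heads, q_len].
        lse = torch.logaddexp(lse1, lse2)                          # merged normalizer
        w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # [batch, q_len, heads, 1]
        w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
        return (o1 * w1 + o2 * w2).to(o1.dtype), lse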

View File

@@ -18,6 +18,8 @@ class RMSNorm(nn.Module):
self,
x: torch.Tensor,
) -> torch.Tensor:
# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
# Callers should reshape 3D tensors to 2D before calling
orig_dtype = x.dtype
x = x.float()
var = x.pow(2).mean(dim=-1, keepdim=True)
@@ -31,6 +33,7 @@ class RMSNorm(nn.Module):
x: torch.Tensor,
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
orig_dtype = x.dtype
x = x.float().add_(residual.float())
residual = x.to(orig_dtype)

View File

@@ -79,8 +79,12 @@ class Qwen3Attention(nn.Module):
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
if not self.qkv_bias:
# Reshape to 2D before RMSNorm to avoid torch.compile recompilation
# q: [num_tokens, num_heads, head_dim] -> [num_tokens * num_heads, head_dim]
# After norm, reshape back to 3D
num_tokens = q.shape[0]
q = self.q_norm(q.reshape(-1, self.head_dim)).view(num_tokens, self.num_heads, self.head_dim)
k = self.k_norm(k.reshape(-1, self.head_dim)).view(num_tokens, self.num_kv_heads, self.head_dim)
q, k = self.rotary_emb(positions, q, k)
o = self.attn(q, k, v)
output = self.o_proj(o.flatten(1, -1))

View File

@@ -27,8 +27,11 @@ class Context:
prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
# Current sequence being processed (for chunked prefill to load KV)
chunked_seq: Any = None
# Position within block for decode (used for reading from Decode region)
decode_pos_in_block: int = 0
# Starting position within block where decode tokens began (for accumulated token tracking)
# Used when batching decode offloads - we need to attend to all accumulated tokens
decode_start_pos_in_block: int = 0
_CONTEXT = Context()
@@ -53,6 +56,7 @@ def set_context(
offload_engine=None,
chunked_seq=None,
decode_pos_in_block=0,
decode_start_pos_in_block=0,
):
global _CONTEXT
_CONTEXT = Context(
@@ -70,6 +74,7 @@ def set_context(
offload_engine=offload_engine,
chunked_seq=chunked_seq,
decode_pos_in_block=decode_pos_in_block,
decode_start_pos_in_block=decode_start_pos_in_block,
)

View File

@@ -66,7 +66,7 @@ Attention mechanisms allow models to focus on relevant parts of the input.
return "".join(prompt_parts) return "".join(prompt_parts)
def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64): def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64, num_prefetch_blocks=2):
"""Test chunked prefill with limited GPU blocks.""" """Test chunked prefill with limited GPU blocks."""
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
@@ -75,15 +75,17 @@ def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64):
print(f"=" * 60) print(f"=" * 60)
print(f" target_input_len: ~{input_len} tokens") print(f" target_input_len: ~{input_len} tokens")
print(f" num_gpu_blocks: {num_gpu_blocks}") print(f" num_gpu_blocks: {num_gpu_blocks}")
print(f" num_prefetch_blocks: {num_prefetch_blocks}")
print() print()
llm = LLM( llm = LLM(
path, path,
enforce_eager=False, enforce_eager=False,
max_model_len=16 * 1024, max_model_len=128 * 1024,
max_num_batched_tokens=16 * 1024, max_num_batched_tokens=128 * 1024,
enable_cpu_offload=True, enable_cpu_offload=True,
num_gpu_blocks=num_gpu_blocks, num_gpu_blocks=num_gpu_blocks,
num_prefetch_blocks=num_prefetch_blocks,
) )
print() print()
@@ -104,7 +106,7 @@ def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64):
return outputs
def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128, num_prefetch_blocks=2):
"""Test chunked decode with limited GPU blocks."""
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
@@ -114,15 +116,17 @@ def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128):
print(f" target_input_len: ~{input_len} tokens") print(f" target_input_len: ~{input_len} tokens")
print(f" output_len: {output_len} tokens") print(f" output_len: {output_len} tokens")
print(f" num_gpu_blocks: {num_gpu_blocks}") print(f" num_gpu_blocks: {num_gpu_blocks}")
print(f" num_prefetch_blocks: {num_prefetch_blocks}")
print() print()
llm = LLM( llm = LLM(
path, path,
enforce_eager=False, enforce_eager=False,
max_model_len=16 * 1024, max_model_len=128 * 1024,
max_num_batched_tokens=16 * 1024, max_num_batched_tokens=128 * 1024,
enable_cpu_offload=True, enable_cpu_offload=True,
num_gpu_blocks=num_gpu_blocks, num_gpu_blocks=num_gpu_blocks,
num_prefetch_blocks=num_prefetch_blocks,
) )
print() print()
@@ -144,9 +148,10 @@ def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128):
if __name__ == "__main__": if __name__ == "__main__":
# Parse arguments # Parse arguments: num_gpu_blocks input_len output_len [num_prefetch_blocks]
num_gpu_blocks = int(sys.argv[1]) if len(sys.argv) > 1 else 10 num_gpu_blocks = int(sys.argv[1]) if len(sys.argv) > 1 else 10
input_len = int(sys.argv[2]) if len(sys.argv) > 2 else 2048 input_len = int(sys.argv[2]) if len(sys.argv) > 2 else 2048
output_len = int(sys.argv[3]) if len(sys.argv) > 3 else 64 output_len = int(sys.argv[3]) if len(sys.argv) > 3 else 64
num_prefetch_blocks = int(sys.argv[4]) if len(sys.argv) > 4 else 2
test_chunked_prefill(num_gpu_blocks, input_len, output_len) test_chunked_prefill(num_gpu_blocks, input_len, output_len, num_prefetch_blocks)
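
For reference, the positional CLI arguments above map directly onto the function defaults; a hypothetical direct invocation (the test module name is assumed, as it is not shown in this diff):

    from test_chunked import test_chunked_prefill, test_chunked_decode   # assumed module name

    test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64, num_prefetch_blocks=2)
    test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128, num_prefetch_blocks=2)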