diff --git a/nanovllm/config.py b/nanovllm/config.py index 1bc27f0..13a8e29 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -22,7 +22,7 @@ class Config: offload_policy: str = "lru" # "lru", "fifo", or full class path num_transfer_streams: int = 4 # Number of CUDA streams for async transfers num_gpu_blocks: int = -1 # User-specified GPU blocks count, -1 = auto (use max available) - num_prefetch_blocks: int = 2 # Prefetch区的block数量,用于三区域GPU Buffer设计 + num_prefetch_blocks: int = 2 # Number of prefetch blocks for three-region GPU buffer design # Computed fields for offload (set in __post_init__ or by ModelRunner) num_gpu_kvcache_blocks: int = -1 diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index db6dfe3..e7f3c8a 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -123,9 +123,9 @@ class ModelRunner: num_gpu_blocks = max_gpu_blocks if config.enable_cpu_offload: - # Ping-Pong设计:CPU是主存储,GPU是工作缓冲区 - # CPU blocks = 支持max_model_len所需的全部blocks(存储一个最大序列的完整KV) - # GPU blocks = Ping-Pong工作缓冲区(用户指定或自动) + # Three-region design: CPU is primary storage, GPU is working buffer + # CPU blocks = all blocks needed to support max_model_len (stores complete KV for one max sequence) + # GPU blocks = three-region working buffer (user-specified or auto) num_cpu_blocks = (config.max_model_len + self.block_size - 1) // self.block_size config.num_gpu_kvcache_blocks = num_gpu_blocks @@ -412,12 +412,12 @@ class ModelRunner: def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool: """ - Check if 三区域 mode should be used. + Check if three-region mode should be used. - Use 三区域 when: + Use three-region when: - CPU offload is enabled - There are blocks on CPU (either allocated there or offloaded) - - Sequence exceeds GPU Compute区 capacity + - Sequence exceeds GPU Compute region capacity """ if not hasattr(self.kvcache_manager, 'offload_engine'): return False @@ -429,10 +429,10 @@ class ModelRunner: # Check if any blocks are on CPU cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq) if cpu_blocks: - # Has CPU blocks - use 三区域 + # Has CPU blocks - use three-region return True - # Check if sequence needs more blocks than GPU Compute区 can hold + # Check if sequence needs more blocks than GPU Compute region can hold compute_size = self.kvcache_manager.offload_engine.num_compute_blocks if seq.num_blocks > compute_size: # Needs chunked processing @@ -630,17 +630,17 @@ class ModelRunner: def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]: """ - Run prefill with 三区域 GPU buffer (CPU is primary storage). + Run prefill with three-region GPU buffer (CPU is primary storage). Flow: 1. All blocks are allocated to CPU (primary storage) - 2. Process tokens in chunks using Compute区 GPU buffer - 3. After each chunk, offload from Compute区 to CPU - 4. Prefetch区 用于加载 previous KV(如果有的话) + 2. Process tokens in chunks using Compute region GPU buffer + 3. After each chunk, offload from Compute region to CPU + 4. 
Prefetch region is used to load previous KV (if any) """ import sys - assert len(seqs) == 1, "三区域 prefill only supports single sequence" + assert len(seqs) == 1, "Three-region prefill only supports single sequence" seq = seqs[0] offload_engine = self.kvcache_manager.offload_engine @@ -648,7 +648,7 @@ class ModelRunner: tokens_per_chunk = compute_size * self.block_size total_tokens = len(seq) - print(f"[三区域 Prefill] Starting: {total_tokens} tokens, " + print(f"[Three-region Prefill] Starting: {total_tokens} tokens, " f"compute_size={compute_size} blocks, chunk={tokens_per_chunk} tokens", file=sys.stderr) @@ -670,12 +670,12 @@ class ModelRunner: end_block_idx = (chunk_end + self.block_size - 1) // self.block_size num_blocks = end_block_idx - start_block_idx - print(f"[三区域 Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, " + print(f"[Three-region Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, " f"blocks {start_block_idx}-{end_block_idx-1}, " f"compute_slots={offload_engine.compute_slots[:num_blocks]}", file=sys.stderr) - # Get GPU slots for this chunk (使用 Compute区) + # Get GPU slots for this chunk (using Compute region) gpu_slots = offload_engine.compute_slots[:num_blocks] # Prepare inputs @@ -695,7 +695,7 @@ class ModelRunner: logical_id = seq.block_table[i] self.kvcache_manager.prefilled_blocks.add(logical_id) - # Offload this chunk from Compute区 to CPU (async) + # Offload this chunk from Compute region to CPU (async) chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx] offload_engine.offload_compute_to_cpu(chunk_cpu_blocks) @@ -707,7 +707,7 @@ class ModelRunner: # Wait for all offloads to complete offload_engine.wait_all_offload_done() - print(f"[三区域 Prefill] Complete: {chunk_num} chunks", file=sys.stderr) + print(f"[Three-region Prefill] Complete: {chunk_num} chunks", file=sys.stderr) # Sample from last logits temperatures = self.prepare_sample(seqs) if self.rank == 0 else None @@ -776,14 +776,15 @@ class ModelRunner: def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]: """ - Run decode with 三区域 GPU buffer. + Run decode with three-region GPU buffer. - All KV is on CPU. Uses Decode区 to write new KV, Compute/Prefetch区 to load KV chunks. - New token's KV is written to Decode区 (slot 0) then offloaded to CPU. + All KV is on CPU. Uses Decode region to write new KV, Compute/Prefetch region to load KV chunks. + New token's KV is written to Decode region (slot 0) then offloaded to CPU only when block is full. - 关键:Decode区 永远不会被 Compute/Prefetch 覆盖,专门用于写入新KV。 + Key: Decode region is never overwritten by Compute/Prefetch, dedicated to writing new KV. + Optimization: Batch offloads - only offload when block is full, attend to all accumulated tokens. 
""" - assert len(seqs) == 1, "三区域 decode only supports single sequence" + assert len(seqs) == 1, "Three-region decode only supports single sequence" seq = seqs[0] offload_engine = self.kvcache_manager.offload_engine @@ -792,13 +793,16 @@ class ModelRunner: input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) - # 使用 Decode区 (slot 0) 写入新 KV + # Use Decode region (slot 0) to write new KV decode_slot = offload_engine.decode_slot # = 0 pos_in_block = (len(seq) - 1) % self.block_size slot = decode_slot * self.block_size + pos_in_block slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + # Get decode start position for accumulated token tracking + decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq) + # Set up context for chunked decode set_context( is_prefill=False, @@ -808,17 +812,22 @@ class ModelRunner: offload_engine=self.kvcache_manager, chunked_seq=seq, decode_pos_in_block=pos_in_block, + decode_start_pos_in_block=decode_start_pos, ) # Run model forward pass logits = self.run_model(input_ids, positions, is_prefill=False) reset_context() - # Offload new KV from Decode区 to CPU - last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq) - if last_cpu_block >= 0: - offload_engine.offload_decode_slot(last_cpu_block) - offload_engine.wait_all_offload_done() + # Only offload when block is full (pos_in_block == block_size - 1) + # This avoids unnecessary offloading on every decode step + if pos_in_block == self.block_size - 1: + last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq) + if last_cpu_block >= 0: + offload_engine.offload_decode_slot(last_cpu_block) + offload_engine.wait_all_offload_done() + # Reset decode start position for next block + self.kvcache_manager.reset_decode_start_pos(seq) # Sample temperatures = self.prepare_sample(seqs) if self.rank == 0 else None diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py index 691cbbf..019c27b 100644 --- a/nanovllm/kvcache/hybrid_manager.py +++ b/nanovllm/kvcache/hybrid_manager.py @@ -95,16 +95,16 @@ class HybridKVCacheManager(KVCacheManager): num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage) block_size: Tokens per block policy: Eviction policy (default: LRU) - cpu_primary: If True, use CPU as primary storage with 三区域 GPU buffer. + cpu_primary: If True, use CPU as primary storage with three-region GPU buffer. If False, use GPU as primary with CPU as overflow (legacy mode). 
- num_prefetch_blocks: Number of prefetch blocks for 三区域 GPU buffer design + num_prefetch_blocks: Number of prefetch blocks for three-region GPU buffer design """ self._block_size = block_size self.num_gpu_slots = num_gpu_slots self.num_cpu_blocks = num_cpu_blocks self.total_blocks = num_gpu_slots + num_cpu_blocks - self.cpu_primary = cpu_primary # 三区域 mode flag - self.num_prefetch_blocks = num_prefetch_blocks # 三区域设计参数 + self.cpu_primary = cpu_primary # Three-region mode flag + self.num_prefetch_blocks = num_prefetch_blocks # Three-region design parameter # Eviction policy self.policy = policy or LRUPolicy() @@ -138,6 +138,10 @@ class HybridKVCacheManager(KVCacheManager): # Track blocks that have been prefilled (KV written) for chunked prefill self.prefilled_blocks: Set[int] = set() # logical_ids + # Track decode starting position within block (for batched offload optimization) + # Key: sequence id, Value: starting position where decode began in current block + self._decode_start_pos: Dict[int, int] = {} + @property def block_size(self) -> int: return self._block_size @@ -337,11 +341,11 @@ class HybridKVCacheManager(KVCacheManager): """ assert not seq.block_table, "Sequence already has blocks" - # Ping-Pong模式:所有blocks都分配到CPU + # Three-region mode: all blocks are allocated to CPU if self.cpu_primary: return self.allocate_cpu_only(seq) - # Legacy模式:GPU为主,CPU为overflow + # Legacy mode: GPU as primary, CPU as overflow h = -1 cache_miss = False @@ -467,7 +471,7 @@ class HybridKVCacheManager(KVCacheManager): block.token_ids = [] if self.cpu_primary: - # Ping-Pong模式:新block分配到CPU + # Three-region mode: new block allocated to CPU if not self.free_cpu_blocks: raise RuntimeError("No free CPU blocks for decode") cpu_block_id = self.free_cpu_blocks.popleft() @@ -476,7 +480,7 @@ class HybridKVCacheManager(KVCacheManager): block.gpu_slot = -1 self.cpu_block_to_logical[cpu_block_id] = logical_id else: - # Legacy模式:新block分配到GPU + # Legacy mode: new block allocated to GPU gpu_slot = self._allocate_gpu_slot() block.location = BlockLocation.GPU block.gpu_slot = gpu_slot @@ -1021,22 +1025,22 @@ class HybridKVCacheManager(KVCacheManager): break return pos - # ========== Ping-Pong 双缓冲支持 ========== + # ========== Three-region double buffering support ========== def allocate_cpu_only(self, seq: Sequence) -> None: """ - 为序列分配 CPU blocks(用于 Ping-Pong 模式)。 + Allocate CPU blocks for sequence (for three-region mode). - 与 allocate() 不同,这里所有 blocks 都分配到 CPU, - GPU 只用作工作缓冲区。 + Unlike allocate(), here all blocks are allocated to CPU, + GPU is only used as working buffer. Args: - seq: 要分配的序列 + seq: Sequence to allocate """ assert not seq.block_table, "Sequence already has blocks" for i in range(seq.num_blocks): - # 分配 CPU block + # Allocate CPU block if not self.free_cpu_blocks: raise RuntimeError( f"No free CPU blocks. Need {seq.num_blocks}, " @@ -1045,7 +1049,7 @@ class HybridKVCacheManager(KVCacheManager): cpu_block_id = self.free_cpu_blocks.popleft() - # 分配逻辑块 + # Allocate logical block logical_id = self.free_logical_ids.popleft() block = self.logical_blocks[logical_id] block.ref_count = 1 @@ -1058,13 +1062,13 @@ class HybridKVCacheManager(KVCacheManager): def get_cpu_block_table(self, seq: Sequence) -> List[int]: """ - 获取序列的 CPU block ID 列表。 + Get CPU block ID list for sequence. 
Args: - seq: 序列 + seq: Sequence Returns: - CPU block IDs 列表,按序列顺序 + List of CPU block IDs in sequence order """ cpu_blocks = [] for logical_id in seq.block_table: @@ -1072,20 +1076,20 @@ class HybridKVCacheManager(KVCacheManager): if block.location == BlockLocation.CPU: cpu_blocks.append(block.cpu_block_id) else: - # 如果 block 在 GPU 上,它应该有一个对应的 CPU block - # 在 Ping-Pong 模式下,所有数据最终都在 CPU 上 + # If block is on GPU, it should have a corresponding CPU block + # In three-region mode, all data ultimately resides on CPU raise RuntimeError( f"Block {logical_id} not on CPU (location={block.location}). " - f"In Ping-Pong mode, all blocks should be on CPU." + f"In three-region mode, all blocks should be on CPU." ) return cpu_blocks def get_all_cpu_blocks(self, seq: Sequence) -> Tuple[List[int], List[int]]: """ - 获取序列的所有 CPU blocks 及其逻辑 ID。 + Get all CPU blocks and their logical IDs for sequence. Args: - seq: 序列 + seq: Sequence Returns: (cpu_block_ids, logical_ids) @@ -1101,13 +1105,13 @@ class HybridKVCacheManager(KVCacheManager): def allocate_next_cpu_block(self, seq: Sequence) -> int: """ - 为序列分配下一个 CPU block(用于 decode 时新 token)。 + Allocate next CPU block for sequence (for new token during decode). Args: - seq: 序列 + seq: Sequence Returns: - 新分配的 CPU block ID + Newly allocated CPU block ID """ if not self.free_cpu_blocks: raise RuntimeError("No free CPU blocks") @@ -1128,15 +1132,15 @@ class HybridKVCacheManager(KVCacheManager): def get_last_cpu_block(self, seq: Sequence) -> int: """ - 获取序列最后一个 block 的 CPU block ID。 + Get CPU block ID of the last block in sequence. - 如果最后一个 block 不在 CPU 上,返回 -1。 + Returns -1 if the last block is not on CPU. Args: - seq: 序列 + seq: Sequence Returns: - CPU block ID,如果不在 CPU 上则返回 -1 + CPU block ID, or -1 if not on CPU """ if not seq.block_table: return -1 @@ -1150,19 +1154,65 @@ class HybridKVCacheManager(KVCacheManager): def get_write_slot_for_pingpong(self, seq: Sequence) -> int: """ - 获取三区域 decode 时新 KV 写入的 GPU slot。 + Get GPU slot for writing new KV during three-region decode. - 在三区域设计中,永远使用 Decode区 (slot 0) 写入新 KV。 - 这样可以避免与 Compute/Prefetch区 的加载操作冲突。 + In three-region design, always use Decode region (slot 0) to write new KV. + This avoids conflicts with Compute/Prefetch region loading operations. Args: - seq: 序列 + seq: Sequence Returns: - GPU slot ID (永远是 decode_slot = 0) + GPU slot ID (always decode_slot = 0) """ return self.offload_engine.decode_slot + def get_decode_start_pos(self, seq: Sequence) -> int: + """ + Get the starting position within block where decode tokens began. + + This is used for batched offload optimization - we need to attend to all + accumulated tokens in decode slot, not just the current one. + + Args: + seq: Sequence + + Returns: + Starting position within block (0 to block_size-1) + """ + seq_id = id(seq) + if seq_id not in self._decode_start_pos: + # First decode step - compute starting position + # After prefill, the last block has some tokens filled + # Decode starts at the next position + prefill_len = len(seq) - 1 # Current len includes the new decode token + self._decode_start_pos[seq_id] = prefill_len % self._block_size + return self._decode_start_pos[seq_id] + + def reset_decode_start_pos(self, seq: Sequence) -> None: + """ + Reset decode start position for sequence. + + Called when block is full and offloaded - next decode starts at position 0. 
+ + Args: + seq: Sequence + """ + seq_id = id(seq) + self._decode_start_pos[seq_id] = 0 + + def clear_decode_tracking(self, seq: Sequence) -> None: + """ + Clear decode position tracking for sequence. + + Called when sequence is deallocated. + + Args: + seq: Sequence + """ + seq_id = id(seq) + self._decode_start_pos.pop(seq_id, None) + def __repr__(self) -> str: return ( f"HybridKVCacheManager(\n" diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index 6a3e9b3..582325e 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -65,44 +65,44 @@ class OffloadEngine: self.kv_dim = num_kv_heads * head_dim self.block_numel = block_size * self.kv_dim - # ========== 三区域 GPU Buffer 配置 ========== - # 约束检查 + # ========== Three-region GPU Buffer configuration ========== + # Constraint checks assert num_gpu_blocks >= 3, \ - f"至少需要3个GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}" + f"Need at least 3 GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}" assert num_prefetch_blocks >= 1, \ - f"至少需要1个prefetch block, got {num_prefetch_blocks}" + f"Need at least 1 prefetch block, got {num_prefetch_blocks}" assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \ - f"GPU blocks不足: 需要 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}" + f"Insufficient GPU blocks: need 1(decode) + 1(compute) + {num_prefetch_blocks}(prefetch), got {num_gpu_blocks}" - # 三区域配置 - # Decode区: [0] - 固定1个block用于写入新KV + # Three-region configuration + # Decode region: [0] - Fixed 1 block for writing new KV self.decode_slot = 0 - # Compute区: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1] + # Compute region: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1] compute_start = 1 compute_end = num_gpu_blocks - num_prefetch_blocks self.compute_slots = list(range(compute_start, compute_end)) self.num_compute_blocks = len(self.compute_slots) - # Prefetch区: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1] + # Prefetch region: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1] prefetch_start = compute_end self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks)) self.num_prefetch_blocks = num_prefetch_blocks self.num_gpu_slots = num_gpu_blocks # alias - # 保留旧的ping/pong属性以兼容(后续会移除) + # Keep old ping/pong attributes for compatibility (will be removed later) self.ping_size = self.num_compute_blocks self.pong_size = self.num_prefetch_blocks self.ping_slots = self.compute_slots.copy() self.pong_slots = self.prefetch_slots.copy() - logger.info(f"三区域 GPU Buffer: decode_slot={self.decode_slot}, " + logger.info(f"Three-region GPU Buffer: decode_slot={self.decode_slot}, " f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}") # ========== Fixed-address GPU KV cache ========== # Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim] - # 使用 zeros 初始化以避免未初始化内存问题 + # Use zeros initialization to avoid uninitialized memory issues self.k_cache_gpu = torch.zeros( num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim, dtype=dtype, device="cuda" @@ -140,15 +140,15 @@ class OffloadEngine: self.compute_stream = torch.cuda.current_stream() self._stream_idx = 0 - # ========== 三区域专用 stream 和事件 ========== - self.transfer_stream_main = torch.cuda.Stream() # 主传输stream + # ========== Three-region dedicated stream and events ========== + self.transfer_stream_main = torch.cuda.Stream() # Main transfer stream - # 同步事件 - 三区域加载完成 + # Sync events - three-region loading 
completion self.compute_ready = torch.cuda.Event() self.prefetch_ready = torch.cuda.Event() self.decode_offload_done = torch.cuda.Event() - # 保留旧的ping/pong事件以兼容(后续会移除) + # Keep old ping/pong events for compatibility (will be removed later) self.pingpong_stream = self.transfer_stream_main self.ping_ready = self.compute_ready self.pong_ready = self.prefetch_ready @@ -568,20 +568,20 @@ class OffloadEngine: f" kv_heads={self.num_kv_heads},\n" f" head_dim={self.head_dim},\n" f" dtype={self.dtype},\n" - f" 三区域: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n" + f" three_regions: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n" f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n" f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n" f")" ) - # ========== Ping-Pong 双缓冲方法 ========== + # ========== Ping-Pong double buffering methods ========== def load_to_ping(self, cpu_block_ids: List[int]) -> None: """ - 异步加载CPU blocks到Ping buffer。 + Async load CPU blocks to Ping buffer. Args: - cpu_block_ids: 要加载的CPU block IDs列表 + cpu_block_ids: List of CPU block IDs to load """ if not cpu_block_ids: self.ping_ready.record(self.pingpong_stream) @@ -594,7 +594,7 @@ class OffloadEngine: for i in range(num_to_load): cpu_id = cpu_block_ids[i] gpu_slot = self.ping_slots[i] - # 所有层一起复制 + # Copy all layers together self.k_cache_gpu[:, gpu_slot].copy_( self.k_cache_cpu[:, cpu_id], non_blocking=True ) @@ -605,10 +605,10 @@ class OffloadEngine: def load_to_pong(self, cpu_block_ids: List[int]) -> None: """ - 异步加载CPU blocks到Pong buffer。 + Async load CPU blocks to Pong buffer. Args: - cpu_block_ids: 要加载的CPU block IDs列表 + cpu_block_ids: List of CPU block IDs to load """ if not cpu_block_ids: self.pong_ready.record(self.pingpong_stream) @@ -630,11 +630,11 @@ class OffloadEngine: self.pong_ready.record(self.pingpong_stream) def wait_ping(self) -> None: - """等待Ping buffer加载完成。""" + """Wait for Ping buffer loading to complete.""" self.compute_stream.wait_event(self.ping_ready) def wait_pong(self) -> None: - """等待Pong buffer加载完成。""" + """Wait for Pong buffer loading to complete.""" self.compute_stream.wait_event(self.pong_ready) def offload_buffer_to_cpu( @@ -643,11 +643,11 @@ class OffloadEngine: cpu_block_ids: List[int], ) -> None: """ - 异步将buffer中的KV offload到CPU。 + Async offload KV from buffer to CPU. Args: - buffer: "ping" 或 "pong" - cpu_block_ids: 目标CPU block IDs列表 + buffer: "ping" or "pong" + cpu_block_ids: Target CPU block IDs list """ slots = self.ping_slots if buffer == "ping" else self.pong_slots event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done @@ -660,7 +660,7 @@ class OffloadEngine: logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}") with torch.cuda.stream(self.pingpong_stream): - # 等待计算完成 + # Wait for compute to complete self.pingpong_stream.wait_stream(self.compute_stream) for i in range(num_to_offload): @@ -680,11 +680,11 @@ class OffloadEngine: cpu_block_id: int, ) -> None: """ - 异步将单个GPU slot的KV offload到CPU。 + Async offload a single GPU slot's KV to CPU. 
Args: gpu_slot: GPU slot ID - cpu_block_id: 目标CPU block ID + cpu_block_id: Target CPU block ID """ logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]") @@ -698,15 +698,15 @@ class OffloadEngine: ) def wait_ping_offload_done(self) -> None: - """等待Ping buffer offload完成。""" + """Wait for Ping buffer offload to complete.""" self.compute_stream.wait_event(self.ping_offload_done) def wait_pong_offload_done(self) -> None: - """等待Pong buffer offload完成。""" + """Wait for Pong buffer offload to complete.""" self.compute_stream.wait_event(self.pong_offload_done) def wait_all_offload_done(self) -> None: - """等待所有offload完成。""" + """Wait for all offload operations to complete.""" self.pingpong_stream.synchronize() def get_kv_for_ping_slots( @@ -715,14 +715,14 @@ class OffloadEngine: num_slots: int, ) -> Tuple[Tensor, Tensor]: """ - 获取Ping buffer中指定数量slots的KV。 + Get KV for specified number of slots in Ping buffer. Args: - layer_id: 层ID - num_slots: 需要的slot数量 + layer_id: Layer ID + num_slots: Number of slots needed Returns: - (k_cache, v_cache),shape: [1, num_slots * block_size, kv_heads, head_dim] + (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim] """ slots = self.ping_slots[:num_slots] k = self.k_cache_gpu[layer_id, slots] # [num_slots, block_size, heads, dim] @@ -738,14 +738,14 @@ class OffloadEngine: num_slots: int, ) -> Tuple[Tensor, Tensor]: """ - 获取Pong buffer中指定数量slots的KV。 + Get KV for specified number of slots in Pong buffer. Args: - layer_id: 层ID - num_slots: 需要的slot数量 + layer_id: Layer ID + num_slots: Number of slots needed Returns: - (k_cache, v_cache),shape: [1, num_slots * block_size, kv_heads, head_dim] + (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim] """ slots = self.pong_slots[:num_slots] k = self.k_cache_gpu[layer_id, slots] @@ -760,14 +760,14 @@ class OffloadEngine: gpu_slots: List[int], ) -> Tuple[Tensor, Tensor]: """ - 获取指定GPU slots的KV。 + Get KV for specified GPU slots. Args: - layer_id: 层ID - gpu_slots: GPU slot IDs列表 + layer_id: Layer ID + gpu_slots: List of GPU slot IDs Returns: - (k_cache, v_cache),shape: [1, len(slots) * block_size, kv_heads, head_dim] + (k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim] """ if not gpu_slots: return None, None @@ -777,14 +777,14 @@ class OffloadEngine: v = v.reshape(1, -1, self.num_kv_heads, self.head_dim) return k, v - # ========== 三区域 GPU Buffer 方法 ========== + # ========== Three-region GPU Buffer methods ========== def load_to_compute(self, cpu_block_ids: List[int]) -> None: """ - 异步加载CPU blocks到Compute区。 + Async load CPU blocks to Compute region. Args: - cpu_block_ids: 要加载的CPU block IDs列表 + cpu_block_ids: List of CPU block IDs to load """ if not cpu_block_ids: self.compute_ready.record(self.transfer_stream_main) @@ -797,7 +797,7 @@ class OffloadEngine: for i in range(num_to_load): cpu_id = cpu_block_ids[i] gpu_slot = self.compute_slots[i] - # 所有层一起复制 + # Copy all layers together self.k_cache_gpu[:, gpu_slot].copy_( self.k_cache_cpu[:, cpu_id], non_blocking=True ) @@ -808,10 +808,10 @@ class OffloadEngine: def load_to_prefetch(self, cpu_block_ids: List[int]) -> None: """ - 异步加载CPU blocks到Prefetch区。 + Async load CPU blocks to Prefetch region. 
Args: - cpu_block_ids: 要加载的CPU block IDs列表 + cpu_block_ids: List of CPU block IDs to load """ if not cpu_block_ids: self.prefetch_ready.record(self.transfer_stream_main) @@ -833,25 +833,25 @@ class OffloadEngine: self.prefetch_ready.record(self.transfer_stream_main) def wait_compute(self) -> None: - """等待Compute区加载完成。""" + """Wait for Compute region loading to complete.""" self.compute_stream.wait_event(self.compute_ready) def wait_prefetch(self) -> None: - """等待Prefetch区加载完成。""" + """Wait for Prefetch region loading to complete.""" self.compute_stream.wait_event(self.prefetch_ready) def swap_compute_prefetch(self) -> None: - """交换Compute区和Prefetch区的角色。""" + """Swap roles of Compute region and Prefetch region.""" self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots - # 同时更新旧的ping/pong slots以保持兼容 + # Also update old ping/pong slots for compatibility self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots def offload_decode_slot(self, cpu_block_id: int) -> None: """ - 将Decode区的KV offload到CPU。 + Offload KV from Decode region to CPU. Args: - cpu_block_id: 目标CPU block ID + cpu_block_id: Target CPU block ID """ logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]") @@ -866,7 +866,7 @@ class OffloadEngine: self.decode_offload_done.record(self.transfer_stream_main) def wait_decode_offload(self) -> None: - """等待Decode区offload完成。""" + """Wait for Decode region offload to complete.""" self.compute_stream.wait_event(self.decode_offload_done) def get_kv_for_compute( @@ -875,14 +875,14 @@ class OffloadEngine: num_blocks: int, ) -> Tuple[Tensor, Tensor]: """ - 获取Compute区中指定数量blocks的KV。 + Get KV for specified number of blocks in Compute region. Args: - layer_id: 层ID - num_blocks: 需要的block数量 + layer_id: Layer ID + num_blocks: Number of blocks needed Returns: - (k_cache, v_cache),shape: [1, num_blocks * block_size, kv_heads, head_dim] + (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim] """ slots = self.compute_slots[:num_blocks] k = self.k_cache_gpu[layer_id, slots] # [num_blocks, block_size, heads, dim] @@ -898,14 +898,14 @@ class OffloadEngine: num_blocks: int, ) -> Tuple[Tensor, Tensor]: """ - 获取Prefetch区中指定数量blocks的KV。 + Get KV for specified number of blocks in Prefetch region. Args: - layer_id: 层ID - num_blocks: 需要的block数量 + layer_id: Layer ID + num_blocks: Number of blocks needed Returns: - (k_cache, v_cache),shape: [1, num_blocks * block_size, kv_heads, head_dim] + (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim] """ slots = self.prefetch_slots[:num_blocks] k = self.k_cache_gpu[layer_id, slots] @@ -920,14 +920,14 @@ class OffloadEngine: pos_in_block: int, ) -> Tuple[Tensor, Tensor]: """ - 获取Decode区指定位置的KV(用于decode时的新token)。 + Get KV at specified position in Decode region (for new token during decode). 
Args: - layer_id: 层ID - pos_in_block: token在block内的位置 (0 to block_size-1) + layer_id: Layer ID + pos_in_block: Token position within block (0 to block_size-1) Returns: - (k_cache, v_cache),shape: [1, 1, kv_heads, head_dim] + (k_cache, v_cache), shape: [1, 1, kv_heads, head_dim] """ k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] # [1, heads, dim] v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1] @@ -935,12 +935,36 @@ class OffloadEngine: v = v.unsqueeze(0) return k, v - def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None: + def get_kv_for_decode_slot_accumulated( + self, + layer_id: int, + num_tokens: int, + ) -> Tuple[Tensor, Tensor]: """ - 将Compute区的KV offload到CPU。 + Get accumulated KV in Decode region (all tokens from position 0 to num_tokens-1). + + Used when batching decode offloads - attend to all accumulated tokens, + not just the current one. Args: - cpu_block_ids: 目标CPU block IDs列表 + layer_id: Layer ID + num_tokens: Number of accumulated tokens (1 to block_size) + + Returns: + (k_cache, v_cache), shape: [1, num_tokens, kv_heads, head_dim] + """ + k = self.k_cache_gpu[layer_id, self.decode_slot, :num_tokens] # [num_tokens, heads, dim] + v = self.v_cache_gpu[layer_id, self.decode_slot, :num_tokens] + k = k.unsqueeze(0) # [1, num_tokens, heads, dim] + v = v.unsqueeze(0) + return k, v + + def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None: + """ + Offload KV from Compute region to CPU. + + Args: + cpu_block_ids: Target CPU block IDs list """ if not cpu_block_ids: return @@ -949,7 +973,7 @@ class OffloadEngine: logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}") with torch.cuda.stream(self.transfer_stream_main): - # 等待计算完成 + # Wait for compute to complete self.transfer_stream_main.wait_stream(self.compute_stream) for i in range(num_to_offload): diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 92b0fc0..ab518f8 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -100,16 +100,16 @@ class Attention(nn.Module): context, ) -> torch.Tensor: """ - Compute attention with 三区域 GPU buffer for chunked prefill. + Compute attention with three-region GPU buffer for chunked prefill. For chunked prefill: - 1. Load previous KV from CPU using Compute/Prefetch区 (if any previous chunks) + 1. Load previous KV from CPU using Compute/Prefetch region (if any previous chunks) 2. Compute attention against previous KV chunks (no causal mask) 3. Compute attention against current chunk's KV (causal) 4. Merge all results using online softmax - 三区域设计保证:当前chunk的KV在Compute区,previous KV从CPU加载到Prefetch区, - 不会发生写入和加载区域重叠的问题。 + Three-region design guarantees: current chunk's KV is in Compute region, previous KV is loaded + from CPU to Prefetch region, so write and load regions never overlap. 
""" from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs @@ -122,7 +122,7 @@ class Attention(nn.Module): o_acc = None lse_acc = None - # Load previous KV from CPU using Compute/Prefetch区 + # Load previous KV from CPU using Compute/Prefetch region # Note: context.offload_engine is actually HybridKVCacheManager kvcache_manager = context.offload_engine seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None @@ -133,12 +133,12 @@ class Attention(nn.Module): if cpu_block_table: offload_engine = kvcache_manager.offload_engine - # 使用 Prefetch区 来加载 previous KV(不会与当前 Compute区 冲突) + # Use Prefetch region to load previous KV (won't conflict with current Compute region) prefetch_size = offload_engine.num_prefetch_blocks num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size - use_compute = True # 交替使用 Compute区 和 Prefetch区 + use_compute = True # Alternate between Compute region and Prefetch region - # 首先将 previous KV 加载到 Prefetch区 + # First load previous KV to Prefetch region # Only layer 0 triggers the load (loads ALL layers at once) first_chunk_end = min(prefetch_size, len(cpu_block_table)) first_chunk_ids = cpu_block_table[:first_chunk_end] @@ -157,14 +157,14 @@ class Attention(nn.Module): next_end = min(next_start + prefetch_size, len(cpu_block_table)) next_chunk_ids = cpu_block_table[next_start:next_end] if use_compute: - # 当前在 Prefetch区,下一个加载到 Compute区(如果有空间) - # 注意:Compute区 此时已写入当前chunk的KV,不能覆盖 - # 所以这里我们使用简单的同步策略:等待当前完成后再加载 - pass # 简化版本:不进行双缓冲,只用 Prefetch区 + # Currently in Prefetch region, next load to Compute region (if space available) + # Note: Compute region already has current chunk's KV written, cannot overwrite + # So here we use simple sync strategy: wait for current to complete before loading + pass # Simplified version: no double buffering, only use Prefetch region else: offload_engine.load_to_prefetch(next_chunk_ids) - # Wait for Prefetch区 and get KV + # Wait for Prefetch region and get KV offload_engine.wait_prefetch() prev_k, prev_v = offload_engine.get_kv_for_prefetch( self.layer_id, num_blocks_in_chunk @@ -185,7 +185,7 @@ class Attention(nn.Module): else: o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) - # Load next chunk to Prefetch区 (if exists) + # Load next chunk to Prefetch region (if exists) if chunk_idx + 1 < num_chunks and self.layer_id == 0: next_start = end next_end = min(next_start + prefetch_size, len(cpu_block_table)) @@ -218,16 +218,16 @@ class Attention(nn.Module): context, ) -> torch.Tensor: """ - Compute decode attention with 三区域 GPU buffer. + Compute decode attention with three-region GPU buffer. - All KV is stored on CPU. Uses Compute区 buffer on GPU: - 1. Load chunk to Compute区 + All KV is stored on CPU. Uses Compute region buffer on GPU: + 1. Load chunk to Compute region 2. Compute attention 3. Repeat for all chunks - 4. Finally, attend to Decode区 (slot 0) which contains the new token's KV + 4. Finally, attend to Decode region (slot 0) which contains the new token's KV 5. Merge all attention outputs using online softmax (LSE) - 关键:新token的KV在Decode区(slot 0),不会被Compute区的加载覆盖。 + Key: new token's KV is in Decode region (slot 0), won't be overwritten by Compute region loading. 
""" from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs @@ -246,10 +246,10 @@ class Attention(nn.Module): if not cpu_block_table: raise RuntimeError("Chunked decode attention failed: no CPU blocks available") - # Get the actual offload_engine for 三区域 operations + # Get the actual offload_engine for three-region operations offload_engine = kvcache_manager.offload_engine - # Calculate chunk info using Compute区 + # Calculate chunk info using Compute region compute_size = offload_engine.num_compute_blocks num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size @@ -262,12 +262,12 @@ class Attention(nn.Module): num_blocks_in_chunk = end - start chunk_ids = cpu_block_table[start:end] - # Load this chunk to Compute区 + # Load this chunk to Compute region # Only layer 0 triggers the load (loads ALL layers at once) if self.layer_id == 0: offload_engine.load_to_compute(chunk_ids) - # Wait for Compute区 to be ready and get KV + # Wait for Compute region to be ready and get KV offload_engine.wait_compute() k_chunk, v_chunk = offload_engine.get_kv_for_compute( self.layer_id, num_blocks_in_chunk @@ -286,21 +286,31 @@ class Attention(nn.Module): else: o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk) - # Now attend to Decode区 (contains the new token's KV) - # This is the token being decoded - only 1 token at position pos_in_block + # Now attend to Decode region (contains accumulated decode tokens) + # When batching offloads, decode slot accumulates multiple tokens + # from decode_start_pos_in_block to decode_pos_in_block (inclusive) pos_in_block = context.decode_pos_in_block - decode_k, decode_v = offload_engine.get_kv_for_decode_slot(self.layer_id, pos_in_block) - decode_o, decode_lse = flash_attn_with_lse( - q_batched, decode_k, decode_v, - softmax_scale=self.scale, - causal=False, - ) + start_pos = context.decode_start_pos_in_block + num_accumulated = pos_in_block - start_pos + 1 - # Merge with accumulated - if o_acc is None: - o_acc = decode_o - else: - o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse) + if num_accumulated > 0: + # Get accumulated KV in decode slot [start_pos : pos_in_block+1] + decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1] + decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1] + decode_k = decode_k.unsqueeze(0) # [1, num_tokens, heads, dim] + decode_v = decode_v.unsqueeze(0) + + decode_o, decode_lse = flash_attn_with_lse( + q_batched, decode_k, decode_v, + softmax_scale=self.scale, + causal=False, + ) + + # Merge with accumulated + if o_acc is None: + o_acc = decode_o + else: + o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse) if o_acc is None: raise RuntimeError("Chunked decode attention failed: no KV available") diff --git a/nanovllm/layers/layernorm.py b/nanovllm/layers/layernorm.py index 71bf419..b86b1f5 100755 --- a/nanovllm/layers/layernorm.py +++ b/nanovllm/layers/layernorm.py @@ -18,6 +18,8 @@ class RMSNorm(nn.Module): self, x: torch.Tensor, ) -> torch.Tensor: + # Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch + # Callers should reshape 3D tensors to 2D before calling orig_dtype = x.dtype x = x.float() var = x.pow(2).mean(dim=-1, keepdim=True) @@ -31,6 +33,7 @@ class RMSNorm(nn.Module): x: torch.Tensor, residual: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: + # Input MUST be 2D [N, D] to avoid recompilation 
due to rank mismatch orig_dtype = x.dtype x = x.float().add_(residual.float()) residual = x.to(orig_dtype) diff --git a/nanovllm/models/qwen3.py b/nanovllm/models/qwen3.py index 5d39e0b..6298d8b 100755 --- a/nanovllm/models/qwen3.py +++ b/nanovllm/models/qwen3.py @@ -79,8 +79,12 @@ class Qwen3Attention(nn.Module): k = k.view(-1, self.num_kv_heads, self.head_dim) v = v.view(-1, self.num_kv_heads, self.head_dim) if not self.qkv_bias: - q = self.q_norm(q) - k = self.k_norm(k) + # Reshape to 2D before RMSNorm to avoid torch.compile recompilation + # q: [num_tokens, num_heads, head_dim] -> [num_tokens * num_heads, head_dim] + # After norm, reshape back to 3D + num_tokens = q.shape[0] + q = self.q_norm(q.reshape(-1, self.head_dim)).view(num_tokens, self.num_heads, self.head_dim) + k = self.k_norm(k.reshape(-1, self.head_dim)).view(num_tokens, self.num_kv_heads, self.head_dim) q, k = self.rotary_emb(positions, q, k) o = self.attn(q, k, v) output = self.o_proj(o.flatten(1, -1)) diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py index f3d0b5e..b32b573 100644 --- a/nanovllm/utils/context.py +++ b/nanovllm/utils/context.py @@ -27,8 +27,11 @@ class Context: prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list) # Current sequence being processed (for chunked prefill to load KV) chunked_seq: Any = None - # Position within block for decode (used for reading from Decode区) + # Position within block for decode (used for reading from Decode region) decode_pos_in_block: int = 0 + # Starting position within block where decode tokens began (for accumulated token tracking) + # Used when batching decode offloads - we need to attend to all accumulated tokens + decode_start_pos_in_block: int = 0 _CONTEXT = Context() @@ -53,6 +56,7 @@ def set_context( offload_engine=None, chunked_seq=None, decode_pos_in_block=0, + decode_start_pos_in_block=0, ): global _CONTEXT _CONTEXT = Context( @@ -70,6 +74,7 @@ def set_context( offload_engine=offload_engine, chunked_seq=chunked_seq, decode_pos_in_block=decode_pos_in_block, + decode_start_pos_in_block=decode_start_pos_in_block, ) diff --git a/tests/test_chunked_attention.py b/tests/test_chunked_attention.py index 56d1d3f..ccd6685 100644 --- a/tests/test_chunked_attention.py +++ b/tests/test_chunked_attention.py @@ -66,7 +66,7 @@ Attention mechanisms allow models to focus on relevant parts of the input. 
return "".join(prompt_parts) -def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64): +def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64, num_prefetch_blocks=2): """Test chunked prefill with limited GPU blocks.""" path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") @@ -75,15 +75,17 @@ def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64): print(f"=" * 60) print(f" target_input_len: ~{input_len} tokens") print(f" num_gpu_blocks: {num_gpu_blocks}") + print(f" num_prefetch_blocks: {num_prefetch_blocks}") print() llm = LLM( path, enforce_eager=False, - max_model_len=16 * 1024, - max_num_batched_tokens=16 * 1024, + max_model_len=128 * 1024, + max_num_batched_tokens=128 * 1024, enable_cpu_offload=True, num_gpu_blocks=num_gpu_blocks, + num_prefetch_blocks=num_prefetch_blocks, ) print() @@ -104,7 +106,7 @@ def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64): return outputs -def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128): +def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128, num_prefetch_blocks=2): """Test chunked decode with limited GPU blocks.""" path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") @@ -114,15 +116,17 @@ def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128): print(f" target_input_len: ~{input_len} tokens") print(f" output_len: {output_len} tokens") print(f" num_gpu_blocks: {num_gpu_blocks}") + print(f" num_prefetch_blocks: {num_prefetch_blocks}") print() llm = LLM( path, enforce_eager=False, - max_model_len=16 * 1024, - max_num_batched_tokens=16 * 1024, + max_model_len=128 * 1024, + max_num_batched_tokens=128 * 1024, enable_cpu_offload=True, num_gpu_blocks=num_gpu_blocks, + num_prefetch_blocks=num_prefetch_blocks, ) print() @@ -144,9 +148,10 @@ def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128): if __name__ == "__main__": - # Parse arguments + # Parse arguments: num_gpu_blocks input_len output_len [num_prefetch_blocks] num_gpu_blocks = int(sys.argv[1]) if len(sys.argv) > 1 else 10 input_len = int(sys.argv[2]) if len(sys.argv) > 2 else 2048 output_len = int(sys.argv[3]) if len(sys.argv) > 3 else 64 + num_prefetch_blocks = int(sys.argv[4]) if len(sys.argv) > 4 else 2 - test_chunked_prefill(num_gpu_blocks, input_len, output_len) \ No newline at end of file + test_chunked_prefill(num_gpu_blocks, input_len, output_len, num_prefetch_blocks) \ No newline at end of file