diff --git a/bench_offload.py b/bench_offload.py index 98e0745..97224aa 100644 --- a/bench_offload.py +++ b/bench_offload.py @@ -40,7 +40,6 @@ def main(): max_model_len=128 * 1024, max_num_batched_tokens=128 * 1024, enable_cpu_offload=True, - cpu_memory_gb=32.0, ) # Warmup diff --git a/nanovllm/config.py b/nanovllm/config.py index 124dc1e..4fbae3b 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -19,7 +19,6 @@ class Config: # CPU Offload configuration enable_cpu_offload: bool = False - cpu_memory_gb: float = 16.0 # CPU memory limit for KV cache offload_policy: str = "lru" # "lru", "fifo", or full class path num_transfer_streams: int = 4 # Number of CUDA streams for async transfers num_gpu_blocks: int = -1 # User-specified GPU blocks count, -1 = auto (use max available) diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index ef64cfc..ccb58d3 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -10,8 +10,11 @@ from nanovllm.models.qwen3 import Qwen3ForCausalLM from nanovllm.layers.sampler import Sampler from nanovllm.utils.context import set_context, get_context, reset_context from nanovllm.utils.loader import load_model +from nanovllm.utils.logger import get_logger from nanovllm.kvcache import create_kvcache_manager, KVCacheManager +logger = get_logger("model_runner") + class ModelRunner: @@ -120,9 +123,11 @@ class ModelRunner: num_gpu_blocks = max_gpu_blocks if config.enable_cpu_offload: - # Calculate CPU blocks based on cpu_memory_gb - cpu_bytes = int(config.cpu_memory_gb * 1024**3) - num_cpu_blocks = cpu_bytes // block_bytes + # Ping-Pong design: CPU is the primary store, GPU is a working buffer + # CPU blocks = all blocks needed to hold max_model_len (the full KV of one maximum-length sequence) + # GPU blocks = Ping-Pong working buffer (user-specified or auto) + num_cpu_blocks = (config.max_model_len + self.block_size - 1) // self.block_size + config.num_gpu_kvcache_blocks = num_gpu_blocks config.num_cpu_kvcache_blocks = num_cpu_blocks # For backward compatibility @@ -143,6 +148,27 @@ class ModelRunner: dtype=hf_config.torch_dtype, ) + # Log KV cache allocation info + gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2) + cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2) + total_memory_mb = gpu_memory_mb + cpu_memory_mb + + if config.enable_cpu_offload: + logger.info( + f"KV Cache allocated (Ping-Pong mode): " + f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), " + f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), " + f"Total={total_memory_mb:.1f}MB, " + f"block_size={self.block_size}, " + f"ping_size={config.num_gpu_kvcache_blocks // 2}" + ) + else: + logger.info( + f"KV Cache allocated: " + f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), " + f"block_size={self.block_size}" + ) + # Bind layer caches to attention modules and set layer_id layer_id = 0 for module in self.model.modules(): @@ -328,7 +354,16 @@ class ModelRunner: return self.model.compute_logits(graph_vars["outputs"][:bs]) def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]: - # Check if chunked prefill is needed + # Check if Ping-Pong mode should be used (all blocks on CPU) + if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'): + use_pingpong = self._should_use_pingpong(seqs, is_prefill) + if use_pingpong: + if is_prefill: + return self.run_pingpong_prefill(seqs) + else: + return self.run_pingpong_decode(seqs) + + # Check if chunked prefill is needed (legacy path) if is_prefill and hasattr(self, 
'kvcache_manager'): needs_chunked = any( hasattr(self.kvcache_manager, 'needs_chunked_prefill') and @@ -338,7 +373,7 @@ class ModelRunner: if needs_chunked: return self.run_chunked_prefill(seqs) - # Check if chunked decode is needed + # Check if chunked decode is needed (legacy path) if not is_prefill and hasattr(self, 'kvcache_manager'): needs_chunked = any( hasattr(self.kvcache_manager, 'needs_chunked_decode') and @@ -355,6 +390,36 @@ class ModelRunner: reset_context() return token_ids + def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool: + """ + Check if Ping-Pong mode should be used. + + Use Ping-Pong when: + - CPU offload is enabled + - There are blocks on CPU (either allocated there or offloaded) + - Sequence exceeds GPU capacity + """ + if not hasattr(self.kvcache_manager, 'offload_engine'): + return False + + for seq in seqs: + if not seq.block_table: + continue # Skip warmup sequences + + # Check if any blocks are on CPU + cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq) + if cpu_blocks: + # Has CPU blocks - use Ping-Pong + return True + + # Check if sequence needs more blocks than GPU can hold + ping_size = self.kvcache_manager.offload_engine.ping_size + if seq.num_blocks > ping_size: + # Needs chunked processing + return True + + return False + def run_chunked_prefill(self, seqs: list[Sequence]) -> list[int]: """ Run prefill in chunks when sequences exceed GPU capacity. @@ -543,6 +608,210 @@ class ModelRunner: return input_ids, positions + def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]: + """ + Run prefill with Ping-Pong dual buffer (CPU is primary storage). + + Flow: + 1. All blocks are allocated to CPU (primary storage) + 2. Process tokens in chunks using Ping/Pong GPU buffers + 3. After each chunk, offload from GPU to CPU + 4. 
Alternate between Ping and Pong buffers + """ + import sys + + assert len(seqs) == 1, "Ping-Pong prefill only supports single sequence" + seq = seqs[0] + + offload_engine = self.kvcache_manager.offload_engine + ping_size = offload_engine.ping_size + tokens_per_chunk = ping_size * self.block_size + + total_tokens = len(seq) + print(f"[Ping-Pong Prefill] Starting: {total_tokens} tokens, " + f"ping_size={ping_size} blocks, chunk={tokens_per_chunk} tokens", + file=sys.stderr) + + current_buffer = "ping" + chunk_num = 0 + logits = None + processed_tokens = 0 + + # Get CPU block table for offload targets + cpu_block_ids, logical_ids = self.kvcache_manager.get_all_cpu_blocks(seq) + + while processed_tokens < total_tokens: + chunk_num += 1 + chunk_start = processed_tokens + chunk_end = min(processed_tokens + tokens_per_chunk, total_tokens) + chunk_tokens = chunk_end - chunk_start + + # Calculate which CPU blocks this chunk covers + start_block_idx = chunk_start // self.block_size + end_block_idx = (chunk_end + self.block_size - 1) // self.block_size + num_blocks = end_block_idx - start_block_idx + + print(f"[Ping-Pong Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, " + f"blocks {start_block_idx}-{end_block_idx-1}, buffer={current_buffer}", + file=sys.stderr) + + # Get GPU slots for this chunk (Ping or Pong buffer) + if current_buffer == "ping": + gpu_slots = offload_engine.ping_slots[:num_blocks] + else: + gpu_slots = offload_engine.pong_slots[:num_blocks] + + # Prepare inputs + input_ids, positions = self._prepare_pingpong_chunk( + seq, chunk_start, chunk_end, gpu_slots, start_block_idx + ) + + if input_ids.numel() == 0: + break + + # Run model forward + logits = self.run_model(input_ids, positions, is_prefill=True) + reset_context() + + # Mark blocks as prefilled + for i in range(start_block_idx, min(end_block_idx, len(seq.block_table))): + logical_id = seq.block_table[i] + self.kvcache_manager.prefilled_blocks.add(logical_id) + + # Offload this chunk from GPU to CPU (async) + chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx] + offload_engine.offload_buffer_to_cpu(current_buffer, chunk_cpu_blocks) + + # Switch buffer for next chunk + if current_buffer == "ping": + offload_engine.wait_ping_offload_done() + current_buffer = "pong" + else: + offload_engine.wait_pong_offload_done() + current_buffer = "ping" + + processed_tokens = chunk_end + + # Wait for all offloads to complete + offload_engine.wait_all_offload_done() + + print(f"[Ping-Pong Prefill] Complete: {chunk_num} chunks", file=sys.stderr) + + # Sample from last logits + temperatures = self.prepare_sample(seqs) if self.rank == 0 else None + if logits is not None: + token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None + else: + token_ids = [0] if self.rank == 0 else None + + return token_ids + + def _prepare_pingpong_chunk( + self, + seq: Sequence, + chunk_start: int, + chunk_end: int, + gpu_slots: list[int], + start_block_idx: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Prepare inputs for a Ping-Pong prefill chunk.""" + # Input tokens for this chunk + input_ids = seq[chunk_start:chunk_end] + positions = list(range(chunk_start, chunk_end)) + + # Create slot mapping pointing to GPU slots + slot_mapping = [] + num_tokens = chunk_end - chunk_start + + token_idx = 0 + for i, gpu_slot in enumerate(gpu_slots): + block_idx = start_block_idx + i + block_start = block_idx * self.block_size + block_end = min(block_start + self.block_size, len(seq)) + + # How many tokens in this block 
for this chunk + overlap_start = max(chunk_start, block_start) + overlap_end = min(chunk_end, block_end) + + for pos in range(overlap_start, overlap_end): + pos_in_block = pos % self.block_size + slot = gpu_slot * self.block_size + pos_in_block + slot_mapping.append(slot) + + # Convert to tensors + input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + + # Set up context for chunked prefill + seqlen = num_tokens + cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + + set_context( + is_prefill=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen, + max_seqlen_k=seqlen, + slot_mapping=slot_mapping, + is_chunked_prefill=True, + offload_engine=self.kvcache_manager, + chunked_seq=seq, + ) + + return input_ids, positions + + def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]: + """ + Run decode with Ping-Pong dual buffer. + + All KV is on CPU. Uses Ping-Pong to load KV chunks and compute attention. + New token's KV is written to GPU then offloaded to CPU. + """ + assert len(seqs) == 1, "Ping-Pong decode only supports single sequence" + seq = seqs[0] + + # Prepare inputs + input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + + # Get write slot for new KV (will use last slot of the buffer used for final chunk) + write_slot = self.kvcache_manager.get_write_slot_for_pingpong(seq) + + # Calculate position in block for slot mapping + last_block_idx = seq.num_blocks - 1 + pos_in_block = (len(seq) - 1) % self.block_size + slot = write_slot * self.block_size + pos_in_block + slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + + # Set up context for chunked decode + set_context( + is_prefill=False, + slot_mapping=slot_mapping, + context_lens=context_len, + is_chunked_prefill=True, # Use chunked attention path + offload_engine=self.kvcache_manager, + chunked_seq=seq, + ) + + # Run model forward pass + logits = self.run_model(input_ids, positions, is_prefill=False) + reset_context() + + # Offload new KV from write_slot to CPU + last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq) + if last_cpu_block >= 0: + self.kvcache_manager.offload_engine.offload_slot_to_cpu(write_slot, last_cpu_block) + self.kvcache_manager.offload_engine.wait_all_offload_done() + + # Sample + temperatures = self.prepare_sample(seqs) if self.rank == 0 else None + token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None + + return token_ids + @torch.inference_mode() def capture_cudagraph(self): config = self.config diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py index b40c51a..1d97887 100644 --- a/nanovllm/kvcache/hybrid_manager.py +++ b/nanovllm/kvcache/hybrid_manager.py @@ -81,20 +81,24 @@ class HybridKVCacheManager(KVCacheManager): num_cpu_blocks: int, block_size: int, policy: Optional[EvictionPolicy] = None, + cpu_primary: bool 
= True, ): """ Initialize hybrid manager. Args: num_gpu_slots: Number of GPU buffer slots (working set) - num_cpu_blocks: Number of CPU pool blocks (overflow) + num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage) block_size: Tokens per block policy: Eviction policy (default: LRU) + cpu_primary: If True, use CPU as primary storage with Ping-Pong GPU buffer. + If False, use GPU as primary with CPU as overflow (legacy mode). """ self._block_size = block_size self.num_gpu_slots = num_gpu_slots self.num_cpu_blocks = num_cpu_blocks self.total_blocks = num_gpu_slots + num_cpu_blocks + self.cpu_primary = cpu_primary # Ping-Pong mode flag # Eviction policy self.policy = policy or LRUPolicy() @@ -321,12 +325,16 @@ class HybridKVCacheManager(KVCacheManager): """ Allocate logical blocks for prefill. - New blocks are allocated on GPU when possible. If GPU is full and all - GPU blocks belong to this sequence (can't evict), remaining blocks - are allocated to CPU for chunked prefill. + In cpu_primary mode (Ping-Pong): All blocks are allocated to CPU. + In legacy mode: Blocks are allocated to GPU when possible, overflow to CPU. """ assert not seq.block_table, "Sequence already has blocks" + # Ping-Pong mode: allocate all blocks on CPU + if self.cpu_primary: + return self.allocate_cpu_only(seq) + + # Legacy mode: GPU is primary, CPU is overflow h = -1 cache_miss = False @@ -451,13 +459,22 @@ class HybridKVCacheManager(KVCacheManager): block.hash = -1 block.token_ids = [] - # New decode blocks go to GPU - gpu_slot = self._allocate_gpu_slot() - block.location = BlockLocation.GPU - block.gpu_slot = gpu_slot - - self.gpu_slot_to_logical[gpu_slot] = logical_id - self.policy.on_block_allocated(gpu_slot, self.current_step) + if self.cpu_primary: + # Ping-Pong mode: allocate the new block on CPU + if not self.free_cpu_blocks: + raise RuntimeError("No free CPU blocks for decode") + cpu_block_id = self.free_cpu_blocks.popleft() + block.location = BlockLocation.CPU + block.cpu_block_id = cpu_block_id + block.gpu_slot = -1 + self.cpu_block_to_logical[cpu_block_id] = logical_id + else: + # Legacy mode: allocate the new block on GPU + gpu_slot = self._allocate_gpu_slot() + block.location = BlockLocation.GPU + block.gpu_slot = gpu_slot + self.gpu_slot_to_logical[gpu_slot] = logical_id + self.policy.on_block_allocated(gpu_slot, self.current_step) block_table.append(logical_id) @@ -993,6 +1010,158 @@ class HybridKVCacheManager(KVCacheManager): break return pos + # ========== Ping-Pong dual-buffer support ========== + + def allocate_cpu_only(self, seq: Sequence) -> None: + """ + Allocate CPU blocks for a sequence (used in Ping-Pong mode). + + Unlike allocate(), all blocks are placed on CPU; + the GPU is used only as a working buffer. + + Args: + seq: Sequence to allocate blocks for + """ + assert not seq.block_table, "Sequence already has blocks" + + for i in range(seq.num_blocks): + # Allocate a CPU block + if not self.free_cpu_blocks: + raise RuntimeError( + f"No free CPU blocks. 
Need {seq.num_blocks}, " + f"available: {len(self.free_cpu_blocks)}" + ) + + cpu_block_id = self.free_cpu_blocks.popleft() + + # Allocate a logical block + logical_id = self.free_logical_ids.popleft() + block = self.logical_blocks[logical_id] + block.ref_count = 1 + block.location = BlockLocation.CPU + block.cpu_block_id = cpu_block_id + block.gpu_slot = -1 + + self.cpu_block_to_logical[cpu_block_id] = logical_id + seq.block_table.append(logical_id) + + def get_cpu_block_table(self, seq: Sequence) -> List[int]: + """ + Get the list of CPU block IDs for a sequence. + + Args: + seq: Sequence + + Returns: + List of CPU block IDs, in sequence order + """ + cpu_blocks = [] + for logical_id in seq.block_table: + block = self.logical_blocks[logical_id] + if block.location == BlockLocation.CPU: + cpu_blocks.append(block.cpu_block_id) + else: + # If a block is on GPU it should still have a corresponding CPU block; + # in Ping-Pong mode all data ultimately lives on CPU + raise RuntimeError( + f"Block {logical_id} not on CPU (location={block.location}). " + f"In Ping-Pong mode, all blocks should be on CPU." + ) + return cpu_blocks + + def get_all_cpu_blocks(self, seq: Sequence) -> Tuple[List[int], List[int]]: + """ + Get all of the sequence's CPU blocks and their logical IDs. + + Args: + seq: Sequence + + Returns: + (cpu_block_ids, logical_ids) + """ + cpu_block_ids = [] + logical_ids = [] + for logical_id in seq.block_table: + block = self.logical_blocks[logical_id] + if block.location == BlockLocation.CPU: + cpu_block_ids.append(block.cpu_block_id) + logical_ids.append(logical_id) + return cpu_block_ids, logical_ids + + def allocate_next_cpu_block(self, seq: Sequence) -> int: + """ + Allocate the next CPU block for a sequence (for the new token during decode). + + Args: + seq: Sequence + + Returns: + Newly allocated CPU block ID + """ + if not self.free_cpu_blocks: + raise RuntimeError("No free CPU blocks") + + cpu_block_id = self.free_cpu_blocks.popleft() + logical_id = self.free_logical_ids.popleft() + + block = self.logical_blocks[logical_id] + block.ref_count = 1 + block.location = BlockLocation.CPU + block.cpu_block_id = cpu_block_id + block.gpu_slot = -1 + + self.cpu_block_to_logical[cpu_block_id] = logical_id + seq.block_table.append(logical_id) + + return cpu_block_id + + def get_last_cpu_block(self, seq: Sequence) -> int: + """ + Get the CPU block ID of the sequence's last block. + + Returns -1 if the last block is not on CPU. + + Args: + seq: Sequence + + Returns: + CPU block ID, or -1 if the block is not on CPU + """ + if not seq.block_table: + return -1 + + last_logical_id = seq.block_table[-1] + block = self.logical_blocks[last_logical_id] + + if block.location == BlockLocation.CPU: + return block.cpu_block_id + return -1 + + def get_write_slot_for_pingpong(self, seq: Sequence) -> int: + """ + Get the GPU slot that new KV is written to during Ping-Pong decode. + + Strategy: use the number of chunks the sequence requires to determine whether + the final chunk used the Ping or the Pong buffer, then use that buffer's last slot. + + Args: + seq: Sequence + + Returns: + GPU slot ID + """ + cpu_blocks, _ = self.get_all_cpu_blocks(seq) + ping_size = self.offload_engine.ping_size + num_chunks = (len(cpu_blocks) + ping_size - 1) // ping_size if cpu_blocks else 0 + + # Which buffer did the final chunk use? + if num_chunks % 2 == 1 or num_chunks == 0: + # Odd number of chunks (or zero): the last chunk used ping + return self.offload_engine.ping_slots[-1] + else: + # Even number of chunks: the last chunk used pong + return self.offload_engine.pong_slots[-1] + def __repr__(self) -> str: return ( f"HybridKVCacheManager(\n" diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index 152d64f..6a8c86d 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -64,6 +64,14 @@ class OffloadEngine: self.kv_dim = num_kv_heads * head_dim self.block_numel = block_size * self.kv_dim + # ========== 
Ping-Pong dual-buffer configuration ========== + assert num_gpu_blocks >= 2, "Ping-Pong requires at least 2 GPU blocks" + self.ping_size = num_gpu_blocks // 2 + self.pong_size = num_gpu_blocks - self.ping_size + self.ping_slots = list(range(self.ping_size)) # [0, 1, 2, ...] + self.pong_slots = list(range(self.ping_size, num_gpu_blocks)) # [ping_size, ping_size+1, ...] + self.num_gpu_slots = num_gpu_blocks # alias + # ========== Fixed-address GPU KV cache ========== # Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim] self.k_cache_gpu = torch.empty( @@ -103,6 +111,17 @@ class OffloadEngine: self.compute_stream = torch.cuda.current_stream() self._stream_idx = 0 + # ========== Dedicated Ping-Pong stream and events ========== + self.pingpong_stream = torch.cuda.Stream() # dedicated to Ping-Pong transfers + + # Sync events - load complete + self.ping_ready = torch.cuda.Event() + self.pong_ready = torch.cuda.Event() + + # Sync events - offload complete + self.ping_offload_done = torch.cuda.Event() + self.pong_offload_done = torch.cuda.Event() + # ========== Event tracking for async transfers ========== self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {} @@ -516,7 +535,211 @@ class OffloadEngine: f" kv_heads={self.num_kv_heads},\n" f" head_dim={self.head_dim},\n" f" dtype={self.dtype},\n" + f" ping_size={self.ping_size}, pong_size={self.pong_size},\n" f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n" f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n" f")" - ) \ No newline at end of file + ) + + # ========== Ping-Pong dual-buffer methods ========== + + def load_to_ping(self, cpu_block_ids: List[int]) -> None: + """ + Asynchronously load CPU blocks into the Ping buffer. + + Args: + cpu_block_ids: List of CPU block IDs to load + """ + if not cpu_block_ids: + self.ping_ready.record(self.pingpong_stream) + return + + num_to_load = min(len(cpu_block_ids), self.ping_size) + logger.debug(f"Ping load: CPU{cpu_block_ids[:num_to_load]} -> GPU ping slots {self.ping_slots[:num_to_load]}") + + with torch.cuda.stream(self.pingpong_stream): + for i in range(num_to_load): + cpu_id = cpu_block_ids[i] + gpu_slot = self.ping_slots[i] + # Copy all layers at once + self.k_cache_gpu[:, gpu_slot].copy_( + self.k_cache_cpu[:, cpu_id], non_blocking=True + ) + self.v_cache_gpu[:, gpu_slot].copy_( + self.v_cache_cpu[:, cpu_id], non_blocking=True + ) + self.ping_ready.record(self.pingpong_stream) + + def load_to_pong(self, cpu_block_ids: List[int]) -> None: + """ + Asynchronously load CPU blocks into the Pong buffer. + + Args: + cpu_block_ids: List of CPU block IDs to load + """ + if not cpu_block_ids: + self.pong_ready.record(self.pingpong_stream) + return + + num_to_load = min(len(cpu_block_ids), self.pong_size) + logger.debug(f"Pong load: CPU{cpu_block_ids[:num_to_load]} -> GPU pong slots {self.pong_slots[:num_to_load]}") + + with torch.cuda.stream(self.pingpong_stream): + for i in range(num_to_load): + cpu_id = cpu_block_ids[i] + gpu_slot = self.pong_slots[i] + self.k_cache_gpu[:, gpu_slot].copy_( + self.k_cache_cpu[:, cpu_id], non_blocking=True + ) + self.v_cache_gpu[:, gpu_slot].copy_( + self.v_cache_cpu[:, cpu_id], non_blocking=True + ) + self.pong_ready.record(self.pingpong_stream) + + def wait_ping(self) -> None: + """Wait for the Ping buffer load to complete.""" + self.compute_stream.wait_event(self.ping_ready) + + def wait_pong(self) -> None: + """Wait for the Pong buffer load to complete.""" + self.compute_stream.wait_event(self.pong_ready) + + def offload_buffer_to_cpu( + self, + buffer: str, + cpu_block_ids: List[int], + ) -> None: + """ + Asynchronously offload the KV held in a buffer to CPU. + + Args: + buffer: "ping" or "pong" + cpu_block_ids: List of target CPU block IDs + """ + slots = self.ping_slots if buffer == "ping" else 
self.pong_slots + event = self.ping_offload_done if buffer == "ping" else self.pong_offload_done + + if not cpu_block_ids: + event.record(self.pingpong_stream) + return + + num_to_offload = min(len(cpu_block_ids), len(slots)) + logger.debug(f"{buffer.capitalize()} offload: GPU {slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}") + + with torch.cuda.stream(self.pingpong_stream): + # Wait for compute to finish + self.pingpong_stream.wait_stream(self.compute_stream) + + for i in range(num_to_offload): + gpu_slot = slots[i] + cpu_id = cpu_block_ids[i] + self.k_cache_cpu[:, cpu_id].copy_( + self.k_cache_gpu[:, gpu_slot], non_blocking=True + ) + self.v_cache_cpu[:, cpu_id].copy_( + self.v_cache_gpu[:, gpu_slot], non_blocking=True + ) + event.record(self.pingpong_stream) + + def offload_slot_to_cpu( + self, + gpu_slot: int, + cpu_block_id: int, + ) -> None: + """ + Asynchronously offload the KV of a single GPU slot to CPU. + + Args: + gpu_slot: GPU slot ID + cpu_block_id: Target CPU block ID + """ + logger.debug(f"Slot offload: GPU[{gpu_slot}] -> CPU[{cpu_block_id}]") + + with torch.cuda.stream(self.pingpong_stream): + self.pingpong_stream.wait_stream(self.compute_stream) + self.k_cache_cpu[:, cpu_block_id].copy_( + self.k_cache_gpu[:, gpu_slot], non_blocking=True + ) + self.v_cache_cpu[:, cpu_block_id].copy_( + self.v_cache_gpu[:, gpu_slot], non_blocking=True + ) + + def wait_ping_offload_done(self) -> None: + """Wait for the Ping buffer offload to complete.""" + self.compute_stream.wait_event(self.ping_offload_done) + + def wait_pong_offload_done(self) -> None: + """Wait for the Pong buffer offload to complete.""" + self.compute_stream.wait_event(self.pong_offload_done) + + def wait_all_offload_done(self) -> None: + """Wait for all offloads to complete.""" + self.pingpong_stream.synchronize() + + def get_kv_for_ping_slots( + self, + layer_id: int, + num_slots: int, + ) -> Tuple[Tensor, Tensor]: + """ + Get the KV held in the first num_slots slots of the Ping buffer. + + Args: + layer_id: Layer ID + num_slots: Number of slots needed + + Returns: + (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim] + """ + slots = self.ping_slots[:num_slots] + k = self.k_cache_gpu[layer_id, slots] # [num_slots, block_size, heads, dim] + v = self.v_cache_gpu[layer_id, slots] + # Reshape: [num_slots, block_size, heads, dim] -> [1, num_slots*block_size, heads, dim] + k = k.reshape(1, -1, self.num_kv_heads, self.head_dim) + v = v.reshape(1, -1, self.num_kv_heads, self.head_dim) + return k, v + + def get_kv_for_pong_slots( + self, + layer_id: int, + num_slots: int, + ) -> Tuple[Tensor, Tensor]: + """ + Get the KV held in the first num_slots slots of the Pong buffer. + + Args: + layer_id: Layer ID + num_slots: Number of slots needed + + Returns: + (k_cache, v_cache), shape: [1, num_slots * block_size, kv_heads, head_dim] + """ + slots = self.pong_slots[:num_slots] + k = self.k_cache_gpu[layer_id, slots] + v = self.v_cache_gpu[layer_id, slots] + k = k.reshape(1, -1, self.num_kv_heads, self.head_dim) + v = v.reshape(1, -1, self.num_kv_heads, self.head_dim) + return k, v + + def get_kv_for_slots( + self, + layer_id: int, + gpu_slots: List[int], + ) -> Tuple[Tensor, Tensor]: + """ + Get the KV held in the given GPU slots. + + Args: + layer_id: Layer ID + gpu_slots: List of GPU slot IDs + + Returns: + (k_cache, v_cache), shape: [1, len(slots) * block_size, kv_heads, head_dim] + """ + if not gpu_slots: + return None, None + k = self.k_cache_gpu[layer_id, gpu_slots] + v = self.v_cache_gpu[layer_id, gpu_slots] + k = k.reshape(1, -1, self.num_kv_heads, self.head_dim) + v = v.reshape(1, -1, self.num_kv_heads, self.head_dim) + return k, v \ No newline at end of file diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index 
bb00fb7..133a71b 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -97,51 +97,89 @@ class Attention(nn.Module): context, ) -> torch.Tensor: """ - Compute attention with chunked KV from CPU cache. + Compute attention with Ping-Pong dual buffer for chunked prefill. For chunked prefill: - 1. Load previous KV from CPU for this layer - 2. Compute attention against previous KV (no causal mask) + 1. Load previous KV from CPU using Ping-Pong (if any previous chunks) + 2. Compute attention against previous KV chunks (no causal mask) 3. Compute attention against current chunk's KV (causal) - 4. Merge results using online softmax + 4. Merge all results using online softmax """ from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs # q, k, v shape: [total_tokens, num_heads, head_dim] - total_tokens = q.shape[0] - # Reshape for flash attention: [batch, seq, heads, dim] q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim] k_batched = k.unsqueeze(0) v_batched = v.unsqueeze(0) - accumulated_o = None - accumulated_lse = None + o_acc = None + lse_acc = None - # Load previous KV from CPU for this layer - if context.offload_engine is not None and self.layer_id >= 0: - # Get the kvcache_manager from context - kvcache_manager = context.offload_engine + # Load previous KV from CPU using Ping-Pong + # Note: context.offload_engine is actually HybridKVCacheManager + kvcache_manager = context.offload_engine + seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None - # For each sequence in the chunk, load previous KV - # Currently assuming single sequence - if hasattr(context, 'chunked_seq') and context.chunked_seq is not None: - prev_k, prev_v = kvcache_manager.load_prev_kv_for_layer( - context.chunked_seq, - self.layer_id, - ) + if kvcache_manager is not None and seq is not None and self.layer_id >= 0: + # Get prefilled CPU blocks (blocks already written in previous chunks) + cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) - if prev_k is not None and prev_v is not None: - # Compute attention against previous KV (no causal mask) + if cpu_block_table: + offload_engine = kvcache_manager.offload_engine + ping_size = offload_engine.ping_size + num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size + current_buffer = "ping" + + # Prefetch first chunk to Ping buffer + first_chunk_end = min(ping_size, len(cpu_block_table)) + first_chunk_ids = cpu_block_table[:first_chunk_end] + offload_engine.load_to_ping(first_chunk_ids) + + for chunk_idx in range(num_chunks): + start = chunk_idx * ping_size + end = min(start + ping_size, len(cpu_block_table)) + num_blocks_in_chunk = end - start + + # Prefetch next chunk to OTHER buffer + if chunk_idx + 1 < num_chunks: + next_start = end + next_end = min(next_start + ping_size, len(cpu_block_table)) + next_chunk_ids = cpu_block_table[next_start:next_end] + if current_buffer == "ping": + offload_engine.load_to_pong(next_chunk_ids) + else: + offload_engine.load_to_ping(next_chunk_ids) + + # Wait for current buffer and get KV + if current_buffer == "ping": + offload_engine.wait_ping() + prev_k, prev_v = offload_engine.get_kv_for_ping_slots( + self.layer_id, num_blocks_in_chunk + ) + else: + offload_engine.wait_pong() + prev_k, prev_v = offload_engine.get_kv_for_pong_slots( + self.layer_id, num_blocks_in_chunk + ) + + # Compute attention against this chunk (no causal mask) prev_o, prev_lse = flash_attn_with_lse( q_batched, prev_k, prev_v, softmax_scale=self.scale, - causal=False, 
# No causal mask for previous context + causal=False, ) - accumulated_o = prev_o - accumulated_lse = prev_lse + + # Merge with accumulated + if o_acc is None: + o_acc, lse_acc = prev_o, prev_lse + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) + + # Switch buffer + current_buffer = "pong" if current_buffer == "ping" else "ping" # Compute attention against current chunk's KV (with causal mask) current_o, current_lse = flash_attn_with_lse( @@ -149,17 +187,14 @@ class Attention(nn.Module): k_batched, v_batched, softmax_scale=self.scale, - causal=True, # Causal mask for current chunk + causal=True, ) # Merge with accumulated - if accumulated_o is None: + if o_acc is None: final_o = current_o else: - final_o, _ = merge_attention_outputs( - accumulated_o, accumulated_lse, - current_o, current_lse, - ) + final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim] return final_o.squeeze(0) @@ -172,12 +207,13 @@ class Attention(nn.Module): context, ) -> torch.Tensor: """ - Compute decode attention with KV spread across CPU and GPU. + Compute decode attention with Ping-Pong dual buffer. - Uses chunked attention similar to chunked prefill: - 1. Process blocks on GPU first (if any) - 2. Load CPU blocks in chunks to GPU slots (per-layer) - 3. Compute attention for each chunk, merge with online softmax + All KV is stored on CPU. Uses Ping-Pong buffers on GPU: + 1. Load first chunk to Ping buffer + 2. While computing on current buffer, prefetch next chunk to other buffer + 3. Alternate between Ping and Pong buffers + 4. Merge attention outputs using online softmax (LSE) """ from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs @@ -185,62 +221,73 @@ class Attention(nn.Module): # Need: [batch, seqlen, heads, dim] q_batched = q.unsqueeze(1) # [batch, 1, heads, dim] + # Note: context.offload_engine is actually HybridKVCacheManager kvcache_manager = context.offload_engine seq = context.chunked_seq + # Get all CPU blocks for this sequence + cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq) + if not cpu_block_table: + raise RuntimeError("Chunked decode attention failed: no CPU blocks available") + + # Get the actual offload_engine for Ping-Pong operations + offload_engine = kvcache_manager.offload_engine + + # Calculate chunk info + ping_size = offload_engine.ping_size + num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size + o_acc = None lse_acc = None + current_buffer = "ping" - # Step 1: Process blocks already on GPU (if any) - gpu_slots, _ = kvcache_manager.get_gpu_blocks_for_decode(seq) - if gpu_slots: - k_gpu, v_gpu = kvcache_manager.get_kv_for_gpu_slots(self.layer_id, gpu_slots) - o_gpu, lse_gpu = flash_attn_with_lse( - q_batched, k_gpu, v_gpu, + # Prefetch first chunk to Ping buffer (loads all layers at once) + first_chunk_end = min(ping_size, len(cpu_block_table)) + first_chunk_ids = cpu_block_table[:first_chunk_end] + offload_engine.load_to_ping(first_chunk_ids) + + for chunk_idx in range(num_chunks): + start = chunk_idx * ping_size + end = min(start + ping_size, len(cpu_block_table)) + num_blocks_in_chunk = end - start + + # Prefetch next chunk to OTHER buffer (overlapped with current computation) + if chunk_idx + 1 < num_chunks: + next_start = end + next_end = min(next_start + ping_size, len(cpu_block_table)) + next_chunk_ids = cpu_block_table[next_start:next_end] + if current_buffer == "ping": + 
offload_engine.load_to_pong(next_chunk_ids) + else: + offload_engine.load_to_ping(next_chunk_ids) + + # Wait for current buffer to be ready and get KV + if current_buffer == "ping": + offload_engine.wait_ping() + k_chunk, v_chunk = offload_engine.get_kv_for_ping_slots( + self.layer_id, num_blocks_in_chunk + ) + else: + offload_engine.wait_pong() + k_chunk, v_chunk = offload_engine.get_kv_for_pong_slots( + self.layer_id, num_blocks_in_chunk + ) + + # Compute attention for this chunk + o_chunk, lse_chunk = flash_attn_with_lse( + q_batched, k_chunk, v_chunk, softmax_scale=self.scale, causal=False, ) - o_acc, lse_acc = o_gpu, lse_gpu - # Step 2: Process CPU blocks in chunks - # Get chunk info from kvcache_manager - cpu_block_ids, cpu_logical_ids, num_chunks = kvcache_manager.get_decode_chunk_info(seq) + # Merge with accumulated + if o_acc is None: + o_acc, lse_acc = o_chunk, lse_chunk + else: + o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk) - if num_chunks > 0: - # Use num_gpu_slots - 1 to avoid the reserved slot (used for write block) - chunk_size = kvcache_manager.num_gpu_slots - 1 - - for chunk_idx in range(num_chunks): - start = chunk_idx * chunk_size - end = min(start + chunk_size, len(cpu_block_ids)) - chunk_cpu_ids = cpu_block_ids[start:end] - - # Load this chunk to GPU slots 0, 1, 2, ... for THIS LAYER - # (slot num_gpu_slots-1 is reserved for write block) - gpu_slots_for_chunk = list(range(len(chunk_cpu_ids))) - kvcache_manager.offload_engine.load_cpu_blocks_to_gpu_slots( - self.layer_id, - chunk_cpu_ids, - gpu_slots_for_chunk, - ) - - # Get KV for this chunk - k_chunk, v_chunk = kvcache_manager.get_kv_for_gpu_slots( - self.layer_id, gpu_slots_for_chunk - ) - - # Compute attention for this chunk - o_chunk, lse_chunk = flash_attn_with_lse( - q_batched, k_chunk, v_chunk, - softmax_scale=self.scale, - causal=False, - ) - - # Merge with accumulated - if o_acc is None: - o_acc, lse_acc = o_chunk, lse_chunk - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk) + # Switch buffer for next iteration + current_buffer = "pong" if current_buffer == "ping" else "ping" if o_acc is None: raise RuntimeError("Chunked decode attention failed: no KV available") diff --git a/tests/test_chunked_attention.py b/tests/test_chunked_attention.py index e9922c1..56d1d3f 100644 --- a/tests/test_chunked_attention.py +++ b/tests/test_chunked_attention.py @@ -14,63 +14,66 @@ os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG" from nanovllm import LLM, SamplingParams -def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=16): +def create_long_context_prompt(target_tokens: int) -> str: + """ + Create a meaningful long context prompt with a question at the end. + The answer depends on information scattered throughout the context. + """ + # Key facts to embed in the context + facts = [ + "The capital of France is Paris.", + "The Eiffel Tower was built in 1889.", + "Python was created by Guido van Rossum.", + "The speed of light is approximately 299,792 kilometers per second.", + "Mount Everest is 8,848 meters tall.", + ] + + # Padding text to reach target length + padding_paragraph = """ +This is additional context information that helps extend the length of the prompt. +Machine learning has revolutionized many fields including computer vision, natural language processing, and robotics. +Deep neural networks can learn complex patterns from large amounts of data. +The transformer architecture has become the foundation of modern language models. 
+Attention mechanisms allow models to focus on relevant parts of the input. +""" + + # Build the prompt + prompt_parts = [] + + # Add instruction + prompt_parts.append("Please read the following information carefully and answer the question at the end.\n\n") + + # Add facts at different positions + current_tokens = 50 # approximate tokens so far + tokens_per_padding = 80 # approximate tokens per padding paragraph + fact_interval = target_tokens // (len(facts) + 1) + + fact_idx = 0 + while current_tokens < target_tokens - 100: + # Add padding + prompt_parts.append(padding_paragraph) + current_tokens += tokens_per_padding + + # Add a fact at intervals + if fact_idx < len(facts) and current_tokens > fact_interval * (fact_idx + 1): + prompt_parts.append(f"\n[Important Fact #{fact_idx + 1}]: {facts[fact_idx]}\n") + current_tokens += 20 + fact_idx += 1 + + # Add the question at the end + prompt_parts.append("\n\nQuestion: Based on the information above, what is the capital of France and when was the Eiffel Tower built? Please answer briefly.\n\nAnswer:") + + return "".join(prompt_parts) + + +def test_chunked_prefill(num_gpu_blocks=10, input_len=8192, output_len=64): """Test chunked prefill with limited GPU blocks.""" path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") - total_blocks = (input_len + 255) // 256 print(f"=" * 60) - print(f"Chunked Prefill Test") + print(f"Chunked Prefill Test (Ping-Pong)") print(f"=" * 60) - print(f" input_len: {input_len} tokens") - print(f" total_blocks: {total_blocks}") - print(f" num_gpu_blocks: {num_gpu_blocks}") - print(f" blocks_on_cpu: {max(0, total_blocks - num_gpu_blocks)}") - print() - - llm = LLM( - path, - enforce_eager=False, - max_model_len=16 * 1024, # 16K is enough for 8K test - max_num_batched_tokens=16 * 1024, - enable_cpu_offload=True, - cpu_memory_gb=4.0, - num_gpu_blocks=num_gpu_blocks, - ) - - print(f"LLM initialized:") - print(f" num_gpu_kvcache_blocks: {llm.model_runner.config.num_gpu_kvcache_blocks}") - print(f" num_cpu_kvcache_blocks: {llm.model_runner.config.num_cpu_kvcache_blocks}") - print() - - # Create prompt with approximate token count - prompt = "Hello " * (input_len // 2) - - print(f"Running generation...") - outputs = llm.generate( - [prompt], - SamplingParams(temperature=0.6, max_tokens=output_len), - use_tqdm=True, - ) - - print() - print(f"Output tokens: {len(outputs[0]['token_ids'])}") - print(f"Output text (first 100 chars): {outputs[0]['text'][:100]}") - print() - return outputs - - -def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=64): - """Test chunked decode with limited GPU blocks.""" - path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") - - total_blocks = (input_len + 255) // 256 - print(f"=" * 60) - print(f"Chunked Decode Test") - print(f"=" * 60) - print(f" input_len: {input_len} tokens") - print(f" output_len: {output_len} tokens") - print(f" total_blocks: {total_blocks}") + print(f" target_input_len: ~{input_len} tokens") print(f" num_gpu_blocks: {num_gpu_blocks}") print() @@ -80,27 +83,62 @@ def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=64): max_model_len=16 * 1024, max_num_batched_tokens=16 * 1024, enable_cpu_offload=True, - cpu_memory_gb=4.0, num_gpu_blocks=num_gpu_blocks, ) - - print(f"LLM initialized:") - print(f" num_gpu_kvcache_blocks: {llm.model_runner.config.num_gpu_kvcache_blocks}") - print(f" num_cpu_kvcache_blocks: {llm.model_runner.config.num_cpu_kvcache_blocks}") print() - prompt = "Hello " * (input_len // 2) + # Create meaningful prompt 
+ prompt = create_long_context_prompt(input_len) print(f"Running generation...") outputs = llm.generate( [prompt], - SamplingParams(temperature=0.6, max_tokens=output_len), - use_tqdm=True, + SamplingParams(temperature=0.1, max_tokens=output_len), # low temperature for more deterministic output + use_tqdm=False, ) print() print(f"Output tokens: {len(outputs[0]['token_ids'])}") - print(f"Output text (first 100 chars): {outputs[0]['text'][:100]}") + print(f"Output text:\n{outputs[0]['text']}") + print() + return outputs + + +def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=128): + """Test chunked decode with limited GPU blocks.""" + path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") + + print(f"=" * 60) + print(f"Chunked Decode Test (Ping-Pong)") + print(f"=" * 60) + print(f" target_input_len: ~{input_len} tokens") + print(f" output_len: {output_len} tokens") + print(f" num_gpu_blocks: {num_gpu_blocks}") + print() + + llm = LLM( + path, + enforce_eager=False, + max_model_len=16 * 1024, + max_num_batched_tokens=16 * 1024, + enable_cpu_offload=True, + num_gpu_blocks=num_gpu_blocks, + ) + print() + + # Create meaningful prompt + prompt = create_long_context_prompt(input_len) + + print(f"Running generation...") + outputs = llm.generate( + [prompt], + SamplingParams(temperature=0.1, max_tokens=output_len), + use_tqdm=False, + ) + + print() + print(f"Output tokens: {len(outputs[0]['token_ids'])}") + print(f"Output text:\n{outputs[0]['text']}") print() return outputs @@ -108,7 +146,7 @@ def test_chunked_decode(num_gpu_blocks=10, input_len=8192, output_len=64): if __name__ == "__main__": # Parse arguments num_gpu_blocks = int(sys.argv[1]) if len(sys.argv) > 1 else 10 - input_len = int(sys.argv[2]) if len(sys.argv) > 2 else 8192 - output_len = int(sys.argv[3]) if len(sys.argv) > 3 else 32 + input_len = int(sys.argv[2]) if len(sys.argv) > 2 else 2048 + output_len = int(sys.argv[3]) if len(sys.argv) > 3 else 64 test_chunked_prefill(num_gpu_blocks, input_len, output_len) \ No newline at end of file
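
Note on the chunk merge used by the Ping-Pong attention paths above: per-chunk partial outputs are combined with log-sum-exp (LSE) weights instead of re-running softmax over the full context. The sketch below reproduces that merge in plain PyTorch under assumed shapes; it is illustrative only and does not call flash-attn or the repo's flash_attn_with_lse / merge_attention_outputs helpers.

import torch

def attn_with_lse(q, k, v, scale):
    # q: [heads, 1, d], k/v: [heads, n, d]
    scores = (q @ k.transpose(-1, -2)) * scale   # [heads, 1, n]
    lse = torch.logsumexp(scores, dim=-1)        # [heads, 1]
    o = torch.softmax(scores, dim=-1) @ v        # [heads, 1, d]
    return o, lse

def merge(o1, lse1, o2, lse2):
    # Online-softmax merge: rescale each partial output by its share of the
    # combined softmax normalizer, so merging chunk results equals attending
    # over the concatenated KV in one shot.
    lse = torch.logaddexp(lse1, lse2)
    w1 = (lse1 - lse).exp().unsqueeze(-1)
    w2 = (lse2 - lse).exp().unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse

torch.manual_seed(0)
heads, d, n = 4, 64, 256
scale = d ** -0.5
q = torch.randn(heads, 1, d)
k = torch.randn(heads, n, d)
v = torch.randn(heads, n, d)

# Chunked: process KV in two halves (what the ping/pong buffers would hold)...
o_a, lse_a = attn_with_lse(q, k[:, :n // 2], v[:, :n // 2], scale)
o_b, lse_b = attn_with_lse(q, k[:, n // 2:], v[:, n // 2:], scale)
o_chunked, _ = merge(o_a, lse_a, o_b, lse_b)

# ...matches attending over the full KV at once.
o_full, _ = attn_with_lse(q, k, v, scale)
assert torch.allclose(o_chunked, o_full, atol=1e-5)

Because this merge is exact, the Ping-Pong loop can stream KV in ping_size-block chunks, in any grouping, and still produce the same attention output as a single pass over the whole context.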