From e85c2b477625b8a7476dd6cfa09bf8fd0a9c9e27 Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Wed, 10 Dec 2025 22:34:00 +0800
Subject: [PATCH] [fix] Fixed kvcache offload bugs.

---
 nanovllm/config.py                 |   1 +
 nanovllm/engine/model_runner.py    | 116 +++++++------
 nanovllm/kvcache/__init__.py       |   2 +
 nanovllm/kvcache/hybrid_manager.py |  35 ++--
 nanovllm/kvcache/offload_engine.py | 254 +++++++++++++++++++++++++++--
 nanovllm/layers/attention.py       | 153 +++++++++--------
 nanovllm/utils/context.py          |   4 +
 7 files changed, 409 insertions(+), 156 deletions(-)

diff --git a/nanovllm/config.py b/nanovllm/config.py
index 4fbae3b..1bc27f0 100644
--- a/nanovllm/config.py
+++ b/nanovllm/config.py
@@ -22,6 +22,7 @@ class Config:
     offload_policy: str = "lru"  # "lru", "fifo", or full class path
     num_transfer_streams: int = 4  # Number of CUDA streams for async transfers
     num_gpu_blocks: int = -1  # User-specified GPU blocks count, -1 = auto (use max available)
+    num_prefetch_blocks: int = 2  # Number of blocks in the prefetch region of the three-region GPU buffer design
 
     # Computed fields for offload (set in __post_init__ or by ModelRunner)
     num_gpu_kvcache_blocks: int = -1
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index ccb58d3..db6dfe3 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -148,19 +148,39 @@ class ModelRunner:
             dtype=hf_config.torch_dtype,
         )
 
-        # Log KV cache allocation info
+        # Log KV cache allocation info with detailed per-token breakdown
         gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
         cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
         total_memory_mb = gpu_memory_mb + cpu_memory_mb
 
+        # Calculate per-token KV cache usage
+        # KV per token = 2 (K+V) * num_layers * kv_heads * head_dim * dtype_size
+        dtype_size = 2 if hf_config.torch_dtype in [torch.float16, torch.bfloat16] else 4
+        per_token_kv_bytes = 2 * hf_config.num_hidden_layers * num_kv_heads * head_dim * dtype_size
+        per_token_kv_kb = per_token_kv_bytes / 1024
+
+        logger.info(
+            f"KV Cache per-token: {per_token_kv_kb:.2f}KB "
+            f"(2 * {hf_config.num_hidden_layers}layers * {num_kv_heads}kv_heads * {head_dim}head_dim * {dtype_size}bytes)"
+        )
+        logger.info(
+            f"KV Cache per-block: {block_bytes / (1024**2):.2f}MB "
+            f"({per_token_kv_kb:.2f}KB * {self.block_size}tokens)"
+        )
+
         if config.enable_cpu_offload:
+            ping_size = config.num_gpu_kvcache_blocks // 2
+            tokens_per_ping = ping_size * self.block_size
             logger.info(
                 f"KV Cache allocated (Ping-Pong mode): "
                 f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
                 f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), "
-                f"Total={total_memory_mb:.1f}MB, "
-                f"block_size={self.block_size}, "
-                f"ping_size={config.num_gpu_kvcache_blocks // 2}"
+                f"Total={total_memory_mb:.1f}MB"
+            )
+            logger.info(
+                f"Ping-Pong config: ping_size={ping_size} blocks, "
+                f"tokens_per_chunk={tokens_per_ping}, "
+                f"block_size={self.block_size}"
             )
         else:
             logger.info(
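The per-token arithmetic logged above can be checked in isolation. A minimal sketch of the same formula; the concrete model dimensions below are illustrative examples, not values taken from this patch:

    def kv_cache_per_token_bytes(num_layers: int, num_kv_heads: int,
                                 head_dim: int, dtype_size: int) -> int:
        # 2 accounts for K and V; dtype_size is 2 for fp16/bf16, 4 for fp32
        return 2 * num_layers * num_kv_heads * head_dim * dtype_size

    # Illustrative example: a 32-layer model with 8 KV heads of dim 128 in fp16
    per_token = kv_cache_per_token_bytes(num_layers=32, num_kv_heads=8,
                                         head_dim=128, dtype_size=2)
    per_block = per_token * 256  # assuming block_size=256 tokens per block
    print(f"{per_token / 1024:.2f} KB/token, {per_block / 1024**2:.2f} MB/block")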
@@ -392,12 +412,12 @@ class ModelRunner:
 
     def _should_use_pingpong(self, seqs: list[Sequence], is_prefill: bool) -> bool:
         """
-        Check if Ping-Pong mode should be used.
+        Check if three-region mode should be used.
 
-        Use Ping-Pong when:
+        Use three-region mode when:
         - CPU offload is enabled
         - There are blocks on CPU (either allocated there or offloaded)
-        - Sequence exceeds GPU capacity
+        - Sequence exceeds GPU compute-region capacity
         """
         if not hasattr(self.kvcache_manager, 'offload_engine'):
             return False
@@ -409,12 +429,12 @@ class ModelRunner:
         # Check if any blocks are on CPU
         cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq)
         if cpu_blocks:
-            # Has CPU blocks - use Ping-Pong
+            # Has CPU blocks - use three-region mode
             return True
 
-        # Check if sequence needs more blocks than GPU can hold
-        ping_size = self.kvcache_manager.offload_engine.ping_size
-        if seq.num_blocks > ping_size:
+        # Check if sequence needs more blocks than the GPU compute region can hold
+        compute_size = self.kvcache_manager.offload_engine.num_compute_blocks
+        if seq.num_blocks > compute_size:
             # Needs chunked processing
             return True
 
@@ -610,29 +630,28 @@ class ModelRunner:
 
     def run_pingpong_prefill(self, seqs: list[Sequence]) -> list[int]:
         """
-        Run prefill with Ping-Pong dual buffer (CPU is primary storage).
+        Run prefill with the three-region GPU buffer (CPU is primary storage).
 
         Flow:
         1. All blocks are allocated to CPU (primary storage)
-        2. Process tokens in chunks using Ping/Pong GPU buffers
-        3. After each chunk, offload from GPU to CPU
-        4. Alternate between Ping and Pong buffers
+        2. Process tokens in chunks using the compute-region GPU buffer
+        3. After each chunk, offload from the compute region to CPU
+        4. The prefetch region is used to load previous KV (if any)
         """
         import sys
 
-        assert len(seqs) == 1, "Ping-Pong prefill only supports single sequence"
+        assert len(seqs) == 1, "Three-region prefill only supports single sequence"
         seq = seqs[0]
 
         offload_engine = self.kvcache_manager.offload_engine
-        ping_size = offload_engine.ping_size
-        tokens_per_chunk = ping_size * self.block_size
+        compute_size = offload_engine.num_compute_blocks
+        tokens_per_chunk = compute_size * self.block_size
         total_tokens = len(seq)
 
-        print(f"[Ping-Pong Prefill] Starting: {total_tokens} tokens, "
-              f"ping_size={ping_size} blocks, chunk={tokens_per_chunk} tokens",
+        print(f"[Three-region Prefill] Starting: {total_tokens} tokens, "
+              f"compute_size={compute_size} blocks, chunk={tokens_per_chunk} tokens",
               file=sys.stderr)
 
-        current_buffer = "ping"
         chunk_num = 0
         logits = None
         processed_tokens = 0
@@ -651,15 +670,13 @@ class ModelRunner:
             end_block_idx = (chunk_end + self.block_size - 1) // self.block_size
             num_blocks = end_block_idx - start_block_idx
 
-            print(f"[Ping-Pong Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
-                  f"blocks {start_block_idx}-{end_block_idx-1}, buffer={current_buffer}",
+            print(f"[Three-region Prefill] Chunk {chunk_num}: tokens {chunk_start}-{chunk_end}, "
+                  f"blocks {start_block_idx}-{end_block_idx-1}, "
+                  f"compute_slots={offload_engine.compute_slots[:num_blocks]}",
                   file=sys.stderr)
 
-            # Get GPU slots for this chunk (Ping or Pong buffer)
-            if current_buffer == "ping":
-                gpu_slots = offload_engine.ping_slots[:num_blocks]
-            else:
-                gpu_slots = offload_engine.pong_slots[:num_blocks]
+            # Get GPU slots for this chunk (use the compute region)
+            gpu_slots = offload_engine.compute_slots[:num_blocks]
 
             # Prepare inputs
             input_ids, positions = self._prepare_pingpong_chunk(
@@ -678,24 +695,19 @@ class ModelRunner:
                 logical_id = seq.block_table[i]
                 self.kvcache_manager.prefilled_blocks.add(logical_id)
 
-            # Offload this chunk from GPU to CPU (async)
+            # Offload this chunk from the compute region to CPU (async)
            chunk_cpu_blocks = cpu_block_ids[start_block_idx:end_block_idx]
-            offload_engine.offload_buffer_to_cpu(current_buffer, chunk_cpu_blocks)
+            offload_engine.offload_compute_to_cpu(chunk_cpu_blocks)
 
-            # Switch buffer for next chunk
-            if current_buffer == "ping":
-                offload_engine.wait_ping_offload_done()
-                current_buffer = "pong"
-            else:
-                offload_engine.wait_pong_offload_done()
-                current_buffer = "ping"
+            # Wait for offload to complete before next chunk
+            offload_engine.wait_all_offload_done()
 
             processed_tokens = chunk_end
 
         # Wait for all offloads to complete
         offload_engine.wait_all_offload_done()
 
-        print(f"[Ping-Pong Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
+        print(f"[Three-region Prefill] Complete: {chunk_num} chunks", file=sys.stderr)
 
         # Sample from last logits
         temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
@@ -764,25 +776,26 @@ class ModelRunner:
 
     def run_pingpong_decode(self, seqs: list[Sequence]) -> list[int]:
         """
-        Run decode with Ping-Pong dual buffer.
+        Run decode with the three-region GPU buffer.
 
-        All KV is on CPU. Uses Ping-Pong to load KV chunks and compute attention.
-        New token's KV is written to GPU then offloaded to CPU.
+        All KV is on CPU. Uses the decode region to write new KV, and the
+        compute/prefetch regions to load KV chunks.
+        New token's KV is written to the decode region (slot 0) then offloaded to CPU.
+
+        Key point: the decode region is never overwritten by compute/prefetch loads;
+        it is dedicated to writing new KV.
         """
-        assert len(seqs) == 1, "Ping-Pong decode only supports single sequence"
+        assert len(seqs) == 1, "Three-region decode only supports single sequence"
         seq = seqs[0]
 
+        offload_engine = self.kvcache_manager.offload_engine
+
         # Prepare inputs
         input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
         positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
 
-        # Get write slot for new KV (will use last slot of the buffer used for final chunk)
-        write_slot = self.kvcache_manager.get_write_slot_for_pingpong(seq)
-
-        # Calculate position in block for slot mapping
-        last_block_idx = seq.num_blocks - 1
+        # Write new KV to the decode region (slot 0)
+        decode_slot = offload_engine.decode_slot  # = 0
         pos_in_block = (len(seq) - 1) % self.block_size
-        slot = write_slot * self.block_size + pos_in_block
+        slot = decode_slot * self.block_size + pos_in_block
 
         slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
 
             is_chunked_prefill=True,  # Use chunked attention path
             offload_engine=self.kvcache_manager,
             chunked_seq=seq,
+            decode_pos_in_block=pos_in_block,
         )
 
         # Run model forward pass
         logits = self.run_model(input_ids, positions, is_prefill=False)
         reset_context()
 
-        # Offload new KV from write_slot to CPU
+        # Offload new KV from the decode region to CPU
         last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
         if last_cpu_block >= 0:
-            self.kvcache_manager.offload_engine.offload_slot_to_cpu(write_slot, last_cpu_block)
-            self.kvcache_manager.offload_engine.wait_all_offload_done()
+            offload_engine.offload_decode_slot(last_cpu_block)
+            offload_engine.wait_all_offload_done()
 
         # Sample
         temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
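For reference, the chunking arithmetic used by run_pingpong_prefill above can be reproduced standalone. This sketch mirrors the loop's token/block bookkeeping; the concrete values in the example are illustrative:

    def prefill_chunks(total_tokens: int, block_size: int, compute_size: int):
        """Yield (token_start, token_end, block_start, block_end) per chunk, as in run_pingpong_prefill."""
        tokens_per_chunk = compute_size * block_size
        processed = 0
        while processed < total_tokens:
            chunk_start = processed
            chunk_end = min(chunk_start + tokens_per_chunk, total_tokens)
            start_block = chunk_start // block_size
            end_block = (chunk_end + block_size - 1) // block_size  # ceiling division
            yield chunk_start, chunk_end, start_block, end_block
            processed = chunk_end

    # Example: 1000 tokens, 16-token blocks, 5 compute-region blocks -> 80-token chunks
    for tok_s, tok_e, blk_s, blk_e in prefill_chunks(1000, 16, 5):
        print(f"tokens {tok_s}-{tok_e}, blocks {blk_s}-{blk_e - 1}")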
diff --git a/nanovllm/kvcache/__init__.py b/nanovllm/kvcache/__init__.py
index a700ce5..ef34a81 100644
--- a/nanovllm/kvcache/__init__.py
+++ b/nanovllm/kvcache/__init__.py
@@ -58,12 +58,14 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
     from nanovllm.kvcache.policies import get_policy
     policy = get_policy(getattr(config, 'offload_policy', 'lru'))
 
+    num_prefetch_blocks = getattr(config, 'num_prefetch_blocks', 2)
+
     return HybridKVCacheManager(
         num_gpu_slots=num_gpu_blocks,
         num_cpu_blocks=num_cpu_blocks,
         block_size=config.kvcache_block_size,
         policy=policy,
+        num_prefetch_blocks=num_prefetch_blocks,
     )
diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py
index 1d97887..691cbbf 100644
--- a/nanovllm/kvcache/hybrid_manager.py
+++ b/nanovllm/kvcache/hybrid_manager.py
@@ -9,6 +9,7 @@ Key design for CUDA Graph compatibility:
 5. Graph replay only needs index updates (tiny overhead)
 """
 
+import logging
 from collections import deque
 from dataclasses import dataclass, field
 from enum import Enum, auto
@@ -16,6 +17,8 @@ from typing import List, Tuple, Dict, Set, Optional
 import torch
 from torch import Tensor
 
+logger = logging.getLogger(__name__)
+
 from nanovllm.engine.sequence import Sequence
 from nanovllm.kvcache.base_manager import KVCacheManager
 from nanovllm.kvcache.offload_engine import OffloadEngine
@@ -82,6 +85,7 @@ class HybridKVCacheManager(KVCacheManager):
         block_size: int,
         policy: Optional[EvictionPolicy] = None,
         cpu_primary: bool = True,
+        num_prefetch_blocks: int = 2,
     ):
         """
         Initialize hybrid manager.
@@ -91,14 +95,16 @@ class HybridKVCacheManager(KVCacheManager):
             num_cpu_blocks: Number of CPU pool blocks (overflow or primary storage)
             block_size: Tokens per block
             policy: Eviction policy (default: LRU)
-            cpu_primary: If True, use CPU as primary storage with Ping-Pong GPU buffer.
+            cpu_primary: If True, use CPU as primary storage with the three-region GPU buffer.
                          If False, use GPU as primary with CPU as overflow (legacy mode).
+            num_prefetch_blocks: Number of prefetch blocks in the three-region GPU buffer design
         """
         self._block_size = block_size
         self.num_gpu_slots = num_gpu_slots
         self.num_cpu_blocks = num_cpu_blocks
         self.total_blocks = num_gpu_slots + num_cpu_blocks
-        self.cpu_primary = cpu_primary  # Ping-Pong mode flag
+        self.cpu_primary = cpu_primary  # Three-region mode flag
+        self.num_prefetch_blocks = num_prefetch_blocks  # Three-region design parameter
 
         # Eviction policy
         self.policy = policy or LRUPolicy()
@@ -156,6 +162,7 @@ class HybridKVCacheManager(KVCacheManager):
             num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=dtype,
+            num_prefetch_blocks=self.num_prefetch_blocks,
         )
 
     def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
@@ -948,6 +955,10 @@ class HybridKVCacheManager(KVCacheManager):
             block = self.logical_blocks[logical_id]
             if block.location == BlockLocation.CPU:
                 cpu_blocks.append(block.cpu_block_id)
+        logger.debug(
+            f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
+            f"returned cpu_blocks={cpu_blocks}"
+        )
         return cpu_blocks
 
     def load_prev_kv_for_layer(
@@ -1139,28 +1150,18 @@ class HybridKVCacheManager(KVCacheManager):
 
     def get_write_slot_for_pingpong(self, seq: Sequence) -> int:
         """
-        Get the GPU slot that new KV is written to during Ping-Pong decode.
+        Get the GPU slot that new KV is written to during three-region decode.
 
-        Strategy: use the number of chunks the sequence needs to decide whether the
-        last chunk used the Ping or Pong buffer, then use that buffer's last slot.
+        In the three-region design, new KV is always written to the decode region (slot 0).
+        This avoids conflicts with loads into the compute/prefetch regions.
 
         Args:
            seq: The sequence
 
         Returns:
-            GPU slot ID
+            GPU slot ID (always decode_slot = 0)
         """
-        cpu_blocks, _ = self.get_all_cpu_blocks(seq)
-        ping_size = self.offload_engine.ping_size
-        num_chunks = (len(cpu_blocks) + ping_size - 1) // ping_size if cpu_blocks else 0
-
-        # Which buffer did the last chunk use?
-        if num_chunks % 2 == 1 or num_chunks == 0:
-            # Odd number of chunks (or zero): the last chunk used Ping
-            return self.offload_engine.ping_slots[-1]
-        else:
-            # Even number of chunks: the last chunk used Pong
-            return self.offload_engine.pong_slots[-1]
+        return self.offload_engine.decode_slot
 
     def __repr__(self) -> str:
         return (
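The rewritten get_write_slot_for_pingpong above is easiest to see next to the parity-based logic it replaces. A standalone sketch of both, as plain index arithmetic with illustrative slot lists:

    def old_write_slot(num_cpu_blocks: int, ping_size: int, ping_slots, pong_slots):
        # Old behaviour: which buffer handled the final chunk decides the write slot.
        num_chunks = (num_cpu_blocks + ping_size - 1) // ping_size if num_cpu_blocks else 0
        return ping_slots[-1] if num_chunks % 2 == 1 or num_chunks == 0 else pong_slots[-1]

    def new_write_slot(decode_slot: int = 0):
        # New behaviour: the decode region is a fixed slot, independent of chunk count.
        return decode_slot

    # With 8 GPU blocks split 4/4, the old answer depends on how many chunks the
    # sequence needs; the new answer is always slot 0.
    print(old_write_slot(7, 4, [0, 1, 2, 3], [4, 5, 6, 7]), new_write_slot())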
diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py
index 6a8c86d..6a3e9b3 100644
--- a/nanovllm/kvcache/offload_engine.py
+++ b/nanovllm/kvcache/offload_engine.py
@@ -53,6 +53,7 @@ class OffloadEngine:
         head_dim: int,
         dtype: torch.dtype = torch.float16,
         num_streams: int = 4,
+        num_prefetch_blocks: int = 2,
     ):
         self.num_layers = num_layers
         self.num_gpu_blocks = num_gpu_blocks
@@ -64,31 +65,59 @@ class OffloadEngine:
         self.kv_dim = num_kv_heads * head_dim
         self.block_numel = block_size * self.kv_dim
 
-        # ========== Ping-Pong double-buffer configuration ==========
-        assert num_gpu_blocks >= 2, "Ping-Pong requires at least 2 GPU blocks"
-        self.ping_size = num_gpu_blocks // 2
-        self.pong_size = num_gpu_blocks - self.ping_size
-        self.ping_slots = list(range(self.ping_size))  # [0, 1, 2, ...]
-        self.pong_slots = list(range(self.ping_size, num_gpu_blocks))  # [ping_size, ping_size+1, ...]
+        # ========== Three-region GPU buffer configuration ==========
+        # Constraint checks
+        assert num_gpu_blocks >= 3, \
+            f"Need at least 3 GPU blocks: 1 decode + 1 compute + 1 prefetch, got {num_gpu_blocks}"
+        assert num_prefetch_blocks >= 1, \
+            f"Need at least 1 prefetch block, got {num_prefetch_blocks}"
+        assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \
+            f"Not enough GPU blocks: need 1 (decode) + 1 (compute) + {num_prefetch_blocks} (prefetch), got {num_gpu_blocks}"
+
+        # Three-region layout
+        # Decode region: [0] - a single fixed block for writing new KV
+        self.decode_slot = 0
+
+        # Compute region: [1, ..., num_gpu_blocks - num_prefetch_blocks - 1]
+        compute_start = 1
+        compute_end = num_gpu_blocks - num_prefetch_blocks
+        self.compute_slots = list(range(compute_start, compute_end))
+        self.num_compute_blocks = len(self.compute_slots)
+
+        # Prefetch region: [num_gpu_blocks - num_prefetch_blocks, ..., num_gpu_blocks - 1]
+        prefetch_start = compute_end
+        self.prefetch_slots = list(range(prefetch_start, num_gpu_blocks))
+        self.num_prefetch_blocks = num_prefetch_blocks
+        self.num_gpu_slots = num_gpu_blocks  # alias
 
+        # Keep the old ping/pong attributes for compatibility (to be removed later)
+        self.ping_size = self.num_compute_blocks
+        self.pong_size = self.num_prefetch_blocks
+        self.ping_slots = self.compute_slots.copy()
+        self.pong_slots = self.prefetch_slots.copy()
+
+        logger.info(f"Three-region GPU buffer: decode_slot={self.decode_slot}, "
+                    f"compute_slots={self.compute_slots}, prefetch_slots={self.prefetch_slots}")
+
         # ========== Fixed-address GPU KV cache ==========
         # Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
-        self.k_cache_gpu = torch.empty(
+        # Initialize with zeros to avoid uninitialized-memory issues
+        self.k_cache_gpu = torch.zeros(
             num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
             dtype=dtype, device="cuda"
         )
-        self.v_cache_gpu = torch.empty(
+        self.v_cache_gpu = torch.zeros(
             num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
             dtype=dtype, device="cuda"
         )
 
         # ========== Fixed-address CPU KV cache (pinned memory) ==========
-        self.k_cache_cpu = torch.empty(
+        self.k_cache_cpu = torch.zeros(
             num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
             dtype=dtype, device="cpu", pin_memory=True
         )
-        self.v_cache_cpu = torch.empty(
+        self.v_cache_cpu = torch.zeros(
             num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
             dtype=dtype, device="cpu", pin_memory=True
         )
@@ -111,14 +140,18 @@ class OffloadEngine:
         self.compute_stream = torch.cuda.current_stream()
         self._stream_idx = 0
 
-        # ========== Dedicated Ping-Pong stream and events ==========
-        self.pingpong_stream = torch.cuda.Stream()  # dedicated to Ping-Pong transfers
+        # ========== Dedicated three-region stream and events ==========
+        self.transfer_stream_main = torch.cuda.Stream()  # main transfer stream
 
-        # Sync events - load completion
-        self.ping_ready = torch.cuda.Event()
-        self.pong_ready = torch.cuda.Event()
+        # Sync events - three-region load completion
+        self.compute_ready = torch.cuda.Event()
+        self.prefetch_ready = torch.cuda.Event()
+        self.decode_offload_done = torch.cuda.Event()
 
-        # Sync events - offload completion
+        # Keep the old ping/pong stream and events for compatibility (to be removed later)
+        self.pingpong_stream = self.transfer_stream_main
+        self.ping_ready = self.compute_ready
+        self.pong_ready = self.prefetch_ready
         self.ping_offload_done = torch.cuda.Event()
         self.pong_offload_done = torch.cuda.Event()
 
@@ -535,7 +568,7 @@ class OffloadEngine:
             f"  kv_heads={self.num_kv_heads},\n"
             f"  head_dim={self.head_dim},\n"
             f"  dtype={self.dtype},\n"
-            f"  ping_size={self.ping_size}, pong_size={self.pong_size},\n"
+            f"  three-region: decode_slot={self.decode_slot}, compute={self.compute_slots}, prefetch={self.prefetch_slots},\n"
             f"  gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
             f"  cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
             f")"
@@ -742,4 +775,189 @@ class OffloadEngine:
         v = self.v_cache_gpu[layer_id, gpu_slots]
         k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
         v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
-        return k, v
\ No newline at end of file
+        return k, v
+
+    # ========== Three-region GPU buffer methods ==========
+
+    def load_to_compute(self, cpu_block_ids: List[int]) -> None:
+        """
+        Asynchronously load CPU blocks into the compute region.
+
+        Args:
+            cpu_block_ids: List of CPU block IDs to load
+        """
+        if not cpu_block_ids:
+            self.compute_ready.record(self.transfer_stream_main)
+            return
+
+        num_to_load = min(len(cpu_block_ids), len(self.compute_slots))
+        logger.debug(f"Compute load: CPU{cpu_block_ids[:num_to_load]} -> GPU compute slots {self.compute_slots[:num_to_load]}")
+
+        with torch.cuda.stream(self.transfer_stream_main):
+            for i in range(num_to_load):
+                cpu_id = cpu_block_ids[i]
+                gpu_slot = self.compute_slots[i]
+                # Copy all layers at once
+                self.k_cache_gpu[:, gpu_slot].copy_(
+                    self.k_cache_cpu[:, cpu_id], non_blocking=True
+                )
+                self.v_cache_gpu[:, gpu_slot].copy_(
+                    self.v_cache_cpu[:, cpu_id], non_blocking=True
+                )
+            self.compute_ready.record(self.transfer_stream_main)
+
+    def load_to_prefetch(self, cpu_block_ids: List[int]) -> None:
+        """
+        Asynchronously load CPU blocks into the prefetch region.
+
+        Args:
+            cpu_block_ids: List of CPU block IDs to load
+        """
+        if not cpu_block_ids:
+            self.prefetch_ready.record(self.transfer_stream_main)
+            return
+
+        num_to_load = min(len(cpu_block_ids), len(self.prefetch_slots))
+        logger.debug(f"Prefetch load: CPU{cpu_block_ids[:num_to_load]} -> GPU prefetch slots {self.prefetch_slots[:num_to_load]}")
+
+        with torch.cuda.stream(self.transfer_stream_main):
+            for i in range(num_to_load):
+                cpu_id = cpu_block_ids[i]
+                gpu_slot = self.prefetch_slots[i]
+                self.k_cache_gpu[:, gpu_slot].copy_(
+                    self.k_cache_cpu[:, cpu_id], non_blocking=True
+                )
+                self.v_cache_gpu[:, gpu_slot].copy_(
+                    self.v_cache_cpu[:, cpu_id], non_blocking=True
+                )
+            self.prefetch_ready.record(self.transfer_stream_main)
+
+    def wait_compute(self) -> None:
+        """Wait for the compute-region load to complete."""
+        self.compute_stream.wait_event(self.compute_ready)
+
+    def wait_prefetch(self) -> None:
+        """Wait for the prefetch-region load to complete."""
+        self.compute_stream.wait_event(self.prefetch_ready)
+
+    def swap_compute_prefetch(self) -> None:
+        """Swap the roles of the compute region and the prefetch region."""
+        self.compute_slots, self.prefetch_slots = self.prefetch_slots, self.compute_slots
+        # Also update the old ping/pong slots to stay compatible
+        self.ping_slots, self.pong_slots = self.pong_slots, self.ping_slots
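The slot layout built in __init__ above can be derived with plain index arithmetic; after swap_compute_prefetch the compute and prefetch lists simply trade places. A small sketch (the num_gpu_blocks and num_prefetch_blocks values in the example are illustrative):

    def three_region_layout(num_gpu_blocks: int, num_prefetch_blocks: int):
        assert num_gpu_blocks >= 1 + 1 + num_prefetch_blocks, \
            "need 1 decode + at least 1 compute + prefetch blocks"
        decode_slot = 0
        compute_end = num_gpu_blocks - num_prefetch_blocks
        compute_slots = list(range(1, compute_end))
        prefetch_slots = list(range(compute_end, num_gpu_blocks))
        return decode_slot, compute_slots, prefetch_slots

    # Example: 8 GPU blocks with 2 prefetch blocks
    # -> decode 0, compute [1, 2, 3, 4, 5], prefetch [6, 7]
    print(three_region_layout(8, 2))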
+    def offload_decode_slot(self, cpu_block_id: int) -> None:
+        """
+        Offload the decode region's KV to CPU.
+
+        Args:
+            cpu_block_id: Target CPU block ID
+        """
+        logger.debug(f"Decode offload: GPU[{self.decode_slot}] -> CPU[{cpu_block_id}]")
+
+        with torch.cuda.stream(self.transfer_stream_main):
+            self.transfer_stream_main.wait_stream(self.compute_stream)
+            self.k_cache_cpu[:, cpu_block_id].copy_(
+                self.k_cache_gpu[:, self.decode_slot], non_blocking=True
+            )
+            self.v_cache_cpu[:, cpu_block_id].copy_(
+                self.v_cache_gpu[:, self.decode_slot], non_blocking=True
+            )
+            self.decode_offload_done.record(self.transfer_stream_main)
+
+    def wait_decode_offload(self) -> None:
+        """Wait for the decode-region offload to complete."""
+        self.compute_stream.wait_event(self.decode_offload_done)
+
+    def get_kv_for_compute(
+        self,
+        layer_id: int,
+        num_blocks: int,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Get the KV of the first `num_blocks` blocks in the compute region.
+
+        Args:
+            layer_id: Layer ID
+            num_blocks: Number of blocks needed
+
+        Returns:
+            (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
+        """
+        slots = self.compute_slots[:num_blocks]
+        k = self.k_cache_gpu[layer_id, slots]  # [num_blocks, block_size, heads, dim]
+        v = self.v_cache_gpu[layer_id, slots]
+        # Reshape: [num_blocks, block_size, heads, dim] -> [1, num_blocks*block_size, heads, dim]
+        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
+        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
+        return k, v
+
+    def get_kv_for_prefetch(
+        self,
+        layer_id: int,
+        num_blocks: int,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Get the KV of the first `num_blocks` blocks in the prefetch region.
+
+        Args:
+            layer_id: Layer ID
+            num_blocks: Number of blocks needed
+
+        Returns:
+            (k_cache, v_cache), shape: [1, num_blocks * block_size, kv_heads, head_dim]
+        """
+        slots = self.prefetch_slots[:num_blocks]
+        k = self.k_cache_gpu[layer_id, slots]
+        v = self.v_cache_gpu[layer_id, slots]
+        k = k.reshape(1, -1, self.num_kv_heads, self.head_dim)
+        v = v.reshape(1, -1, self.num_kv_heads, self.head_dim)
+        return k, v
+
+    def get_kv_for_decode_slot(
+        self,
+        layer_id: int,
+        pos_in_block: int,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Get the KV at a given position in the decode region (the new token written during decode).
+
+        Args:
+            layer_id: Layer ID
+            pos_in_block: Position of the token within the block (0 to block_size-1)
+
+        Returns:
+            (k_cache, v_cache), shape: [1, 1, kv_heads, head_dim]
+        """
+        k = self.k_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]  # [1, heads, dim]
+        v = self.v_cache_gpu[layer_id, self.decode_slot, pos_in_block:pos_in_block+1]
+        k = k.unsqueeze(0)  # [1, 1, heads, dim]
+        v = v.unsqueeze(0)
+        return k, v
+
+    def offload_compute_to_cpu(self, cpu_block_ids: List[int]) -> None:
+        """
+        Offload the compute region's KV to CPU.
+
+        Args:
+            cpu_block_ids: List of target CPU block IDs
+        """
+        if not cpu_block_ids:
+            return
+
+        num_to_offload = min(len(cpu_block_ids), len(self.compute_slots))
+        logger.debug(f"Compute offload: GPU {self.compute_slots[:num_to_offload]} -> CPU{cpu_block_ids[:num_to_offload]}")
+
+        with torch.cuda.stream(self.transfer_stream_main):
+            # Wait for the compute stream to finish before copying
+            self.transfer_stream_main.wait_stream(self.compute_stream)
+
+            for i in range(num_to_offload):
+                gpu_slot = self.compute_slots[i]
+                cpu_id = cpu_block_ids[i]
+                self.k_cache_cpu[:, cpu_id].copy_(
+                    self.k_cache_gpu[:, gpu_slot], non_blocking=True
+                )
+                self.v_cache_cpu[:, cpu_id].copy_(
+                    self.v_cache_gpu[:, gpu_slot], non_blocking=True
+                )
\ No newline at end of file
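The transfer helpers above all follow the same pattern: issue non-blocking copies between pinned CPU memory and fixed GPU slots on a dedicated stream, record an event, and have the compute stream wait on that event. A minimal self-contained sketch of that pattern (the tensor shapes and names are illustrative, not the engine's real buffers):

    import torch

    if torch.cuda.is_available():
        # A pinned CPU block and a fixed GPU slot (illustrative shapes)
        cpu_block = torch.randn(4, 256, 8, 128, pin_memory=True)
        gpu_slot = torch.zeros(4, 256, 8, 128, device="cuda")

        transfer_stream = torch.cuda.Stream()
        ready = torch.cuda.Event()

        # Host-to-device copy on the side stream, overlapped with other GPU work
        with torch.cuda.stream(transfer_stream):
            gpu_slot.copy_(cpu_block, non_blocking=True)
            ready.record(transfer_stream)

        # The default (compute) stream only waits when it actually needs the data
        torch.cuda.current_stream().wait_event(ready)
        print(gpu_slot.mean().item())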
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index 133a71b..92b0fc0 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -1,3 +1,4 @@
+import logging
 import torch
 from torch import nn
 import triton
@@ -6,6 +7,8 @@ import triton.language as tl
 from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 from nanovllm.utils.context import get_context
 
+logger = logging.getLogger(__name__)
+
 
 @triton.jit
 def store_kvcache_kernel(
@@ -97,13 +100,16 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute attention with Ping-Pong dual buffer for chunked prefill.
+        Compute attention with the three-region GPU buffer for chunked prefill.
 
         For chunked prefill:
-        1. Load previous KV from CPU using Ping-Pong (if any previous chunks)
+        1. Load previous KV from CPU using the compute/prefetch regions (if any previous chunks)
         2. Compute attention against previous KV chunks (no causal mask)
         3. Compute attention against current chunk's KV (causal)
         4. Merge all results using online softmax
+
+        The three-region design guarantees that the current chunk's KV lives in the compute
+        region while previous KV is loaded from CPU into the prefetch region, so the write
+        and load regions never overlap.
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
@@ -116,7 +122,7 @@ class Attention(nn.Module):
         o_acc = None
         lse_acc = None
 
-        # Load previous KV from CPU using Ping-Pong
+        # Load previous KV from CPU using the compute/prefetch regions
         # Note: context.offload_engine is actually HybridKVCacheManager
         kvcache_manager = context.offload_engine
         seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
@@ -127,41 +133,42 @@ class Attention(nn.Module):
 
         if cpu_block_table:
             offload_engine = kvcache_manager.offload_engine
-            ping_size = offload_engine.ping_size
-            num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
-            current_buffer = "ping"
+            # Use the prefetch region to load previous KV (does not conflict with the current compute region)
+            prefetch_size = offload_engine.num_prefetch_blocks
+            num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size
+            use_compute = True  # alternate between the compute and prefetch regions
 
-            # Prefetch first chunk to Ping buffer
-            first_chunk_end = min(ping_size, len(cpu_block_table))
+            # First, load previous KV into the prefetch region
+            # Only layer 0 triggers the load (loads ALL layers at once)
+            first_chunk_end = min(prefetch_size, len(cpu_block_table))
             first_chunk_ids = cpu_block_table[:first_chunk_end]
-            offload_engine.load_to_ping(first_chunk_ids)
+            if self.layer_id == 0:
+                offload_engine.load_to_prefetch(first_chunk_ids)
 
             for chunk_idx in range(num_chunks):
-                start = chunk_idx * ping_size
-                end = min(start + ping_size, len(cpu_block_table))
+                start = chunk_idx * prefetch_size
+                end = min(start + prefetch_size, len(cpu_block_table))
                 num_blocks_in_chunk = end - start
 
-                # Prefetch next chunk to OTHER buffer
-                if chunk_idx + 1 < num_chunks:
+                # Prefetch next chunk to the other buffer (if it exists)
+                # Only layer 0 triggers the load
+                if chunk_idx + 1 < num_chunks and self.layer_id == 0:
                     next_start = end
-                    next_end = min(next_start + ping_size, len(cpu_block_table))
+                    next_end = min(next_start + prefetch_size, len(cpu_block_table))
                     next_chunk_ids = cpu_block_table[next_start:next_end]
-                    if current_buffer == "ping":
-                        offload_engine.load_to_pong(next_chunk_ids)
+                    if use_compute:
+                        # Currently using the prefetch region; the next chunk would go to the
+                        # compute region if there were room. Note: the compute region already
+                        # holds the current chunk's KV and must not be overwritten, so use a
+                        # simple synchronous strategy: load the next chunk only after the
+                        # current one is done.
+                        pass  # simplified version: no double buffering, only the prefetch region is used
                     else:
-                        offload_engine.load_to_ping(next_chunk_ids)
+                        offload_engine.load_to_prefetch(next_chunk_ids)
 
-                # Wait for current buffer and get KV
-                if current_buffer == "ping":
-                    offload_engine.wait_ping()
-                    prev_k, prev_v = offload_engine.get_kv_for_ping_slots(
-                        self.layer_id, num_blocks_in_chunk
-                    )
-                else:
-                    offload_engine.wait_pong()
-                    prev_k, prev_v = offload_engine.get_kv_for_pong_slots(
-                        self.layer_id, num_blocks_in_chunk
-                    )
+                # Wait for the prefetch region and get KV
+                offload_engine.wait_prefetch()
+                prev_k, prev_v = offload_engine.get_kv_for_prefetch(
+                    self.layer_id, num_blocks_in_chunk
+                )
 
                 # Compute attention against this chunk (no causal mask)
                 prev_o, prev_lse = flash_attn_with_lse(
@@ -178,8 +185,12 @@ class Attention(nn.Module):
                 else:
                     o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
 
-                # Switch buffer
-                current_buffer = "pong" if current_buffer == "ping" else "ping"
+                # Load the next chunk into the prefetch region (if it exists)
+                if chunk_idx + 1 < num_chunks and self.layer_id == 0:
+                    next_start = end
+                    next_end = min(next_start + prefetch_size, len(cpu_block_table))
+                    next_chunk_ids = cpu_block_table[next_start:next_end]
+                    offload_engine.load_to_prefetch(next_chunk_ids)
 
         # Compute attention against current chunk's KV (with causal mask)
         current_o, current_lse = flash_attn_with_lse(
@@ -207,13 +218,16 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute decode attention with Ping-Pong dual buffer.
+        Compute decode attention with the three-region GPU buffer.
 
-        All KV is stored on CPU. Uses Ping-Pong buffers on GPU:
-        1. Load first chunk to Ping buffer
-        2. While computing on current buffer, prefetch next chunk to other buffer
-        3. Alternate between Ping and Pong buffers
-        4. Merge attention outputs using online softmax (LSE)
+        All KV is stored on CPU. Uses the compute-region buffer on GPU:
+        1. Load a chunk into the compute region
+        2. Compute attention
+        3. Repeat for all chunks
+        4. Finally, attend to the decode region (slot 0), which contains the new token's KV
+        5. Merge all attention outputs using online softmax (LSE)
+
+        Key point: the new token's KV lives in the decode region (slot 0) and is never
+        overwritten by loads into the compute region.
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
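For reference, the merge_attention_outputs step used above combines per-chunk partial attention outputs via their log-sum-exp statistics. A standalone sketch of that online-softmax merge; the shape convention ([batch, seq_q, heads, dim] outputs, [batch, heads, seq_q] LSE) is an assumption for illustration, and this is reference math, not the library's actual implementation:

    import torch

    def merge_with_lse(o1, lse1, o2, lse2):
        # o*: [batch, seq_q, heads, dim], lse*: [batch, heads, seq_q]
        max_lse = torch.maximum(lse1, lse2)
        w1 = torch.exp(lse1 - max_lse)
        w2 = torch.exp(lse2 - max_lse)
        denom = w1 + w2
        # Broadcast the per-query weights over the head_dim axis: -> [batch, seq_q, heads, 1]
        w1_b = (w1 / denom).permute(0, 2, 1).unsqueeze(-1)
        w2_b = (w2 / denom).permute(0, 2, 1).unsqueeze(-1)
        o = o1 * w1_b + o2 * w2_b
        lse = max_lse + torch.log(denom)
        return o, lse

    # Tiny smoke test with random partial results
    o1, o2 = torch.randn(1, 4, 2, 8), torch.randn(1, 4, 2, 8)
    lse1, lse2 = torch.randn(1, 2, 4), torch.randn(1, 2, 4)
    o, lse = merge_with_lse(o1, lse1, o2, lse2)
    print(o.shape, lse.shape)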
@@ -227,51 +241,37 @@ class Attention(nn.Module):
 
         # Get all CPU blocks for this sequence
         cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
+        if self.layer_id == 0:
+            logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
         if not cpu_block_table:
             raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
 
-        # Get the actual offload_engine for Ping-Pong operations
+        # Get the actual offload_engine for three-region operations
         offload_engine = kvcache_manager.offload_engine
 
-        # Calculate chunk info
-        ping_size = offload_engine.ping_size
-        num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
+        # Calculate chunk info using the compute region
+        compute_size = offload_engine.num_compute_blocks
+        num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size
 
         o_acc = None
         lse_acc = None
-        current_buffer = "ping"
-
-        # Prefetch first chunk to Ping buffer (loads all layers at once)
-        first_chunk_end = min(ping_size, len(cpu_block_table))
-        first_chunk_ids = cpu_block_table[:first_chunk_end]
-        offload_engine.load_to_ping(first_chunk_ids)
 
         for chunk_idx in range(num_chunks):
-            start = chunk_idx * ping_size
-            end = min(start + ping_size, len(cpu_block_table))
+            start = chunk_idx * compute_size
+            end = min(start + compute_size, len(cpu_block_table))
             num_blocks_in_chunk = end - start
+            chunk_ids = cpu_block_table[start:end]
 
-            # Prefetch next chunk to OTHER buffer (overlapped with current computation)
-            if chunk_idx + 1 < num_chunks:
-                next_start = end
-                next_end = min(next_start + ping_size, len(cpu_block_table))
-                next_chunk_ids = cpu_block_table[next_start:next_end]
-                if current_buffer == "ping":
-                    offload_engine.load_to_pong(next_chunk_ids)
-                else:
-                    offload_engine.load_to_ping(next_chunk_ids)
+            # Load this chunk into the compute region
+            # Only layer 0 triggers the load (loads ALL layers at once)
+            if self.layer_id == 0:
+                offload_engine.load_to_compute(chunk_ids)
 
-            # Wait for current buffer to be ready and get KV
-            if current_buffer == "ping":
-                offload_engine.wait_ping()
-                k_chunk, v_chunk = offload_engine.get_kv_for_ping_slots(
-                    self.layer_id, num_blocks_in_chunk
-                )
-            else:
-                offload_engine.wait_pong()
-                k_chunk, v_chunk = offload_engine.get_kv_for_pong_slots(
-                    self.layer_id, num_blocks_in_chunk
-                )
+            # Wait for the compute region to be ready and get KV
+            offload_engine.wait_compute()
+            k_chunk, v_chunk = offload_engine.get_kv_for_compute(
+                self.layer_id, num_blocks_in_chunk
+            )
 
             # Compute attention for this chunk
             o_chunk, lse_chunk = flash_attn_with_lse(
@@ -286,8 +286,21 @@ class Attention(nn.Module):
             else:
                 o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
 
-            # Switch buffer for next iteration
-            current_buffer = "pong" if current_buffer == "ping" else "ping"
+        # Now attend to the decode region (contains the new token's KV)
+        # This is the token being decoded - only 1 token at position pos_in_block
+        pos_in_block = context.decode_pos_in_block
+        decode_k, decode_v = offload_engine.get_kv_for_decode_slot(self.layer_id, pos_in_block)
+        decode_o, decode_lse = flash_attn_with_lse(
+            q_batched, decode_k, decode_v,
+            softmax_scale=self.scale,
+            causal=False,
+        )
+
+        # Merge with the accumulated output
+        if o_acc is None:
+            o_acc = decode_o
+        else:
+            o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
 
         if o_acc is None:
             raise RuntimeError("Chunked decode attention failed: no KV available")
diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py
index b6b09a4..f3d0b5e 100644
--- a/nanovllm/utils/context.py
+++ b/nanovllm/utils/context.py
@@ -27,6 +27,8 @@ class Context:
     prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
     # Current sequence being processed (for chunked prefill to load KV)
     chunked_seq: Any = None
+    # Position within block for decode (used for reading from the decode region)
+    decode_pos_in_block: int = 0
 
 
 _CONTEXT = Context()
@@ -50,6 +52,7 @@ def set_context(
     chunk_offset=0,
     offload_engine=None,
     chunked_seq=None,
+    decode_pos_in_block=0,
 ):
     global _CONTEXT
     _CONTEXT = Context(
@@ -66,6 +69,7 @@ def set_context(
         chunk_offset=chunk_offset,
         offload_engine=offload_engine,
         chunked_seq=chunked_seq,
+        decode_pos_in_block=decode_pos_in_block,
     )
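The decode_pos_in_block field added here is the per-token index the attention layer uses to read the new KV back out of the decode slot; the arithmetic matches run_pingpong_decode above. A tiny sketch (the block_size and sequence length in the example are illustrative):

    def decode_slot_mapping(seq_len: int, block_size: int, decode_slot: int = 0):
        # Position of the token being decoded inside its (last) block,
        # and the flat GPU cache slot it is written to.
        pos_in_block = (seq_len - 1) % block_size
        flat_slot = decode_slot * block_size + pos_in_block
        return pos_in_block, flat_slot

    # Example: sequence length 1000 with 16-token blocks -> position 7 of the decode block
    print(decode_slot_mapping(1000, 16))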