import pickle
import torch
import torch.distributed as dist
from multiprocessing.synchronize import Event
from multiprocessing.shared_memory import SharedMemory

from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence
from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import GreedySampler
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model
from nanovllm.utils.logger import get_logger
from nanovllm.kvcache import create_kvcache_manager, KVCacheManager

logger = get_logger("model_runner")


class ModelRunner:

    def __init__(self, config: Config, rank: int, event: Event | list[Event]):
        self.config = config
        hf_config = config.hf_config
        self.block_size = config.kvcache_block_size
        self.enforce_eager = config.enforce_eager
        self.world_size = config.tensor_parallel_size
        self.rank = rank
        self.event = event

        dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
        torch.cuda.set_device(rank)
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(hf_config.torch_dtype)
        torch.set_default_device("cuda")
        self.model = Qwen3ForCausalLM(hf_config)
        load_model(self.model, config.model)
        self.sampler = GreedySampler()
        self.warmup_model()
        self.allocate_kv_cache()
        if not self.enforce_eager:
            self.capture_cudagraph()
        torch.set_default_device("cpu")
        torch.set_default_dtype(default_dtype)

        if self.world_size > 1:
            if rank == 0:
                self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
                dist.barrier()
            else:
                dist.barrier()
                self.shm = SharedMemory(name="nanovllm")
                self.loop()

    def exit(self):
        if self.world_size > 1:
            self.shm.close()
            dist.barrier()
            if self.rank == 0:
                self.shm.unlink()
        if not self.enforce_eager:
            del self.graphs, self.graph_pool
        torch.cuda.synchronize()
        dist.destroy_process_group()

    def loop(self):
        while True:
            method_name, args = self.read_shm()
            self.call(method_name, *args)
            if method_name == "exit":
                break

    def read_shm(self):
        assert self.world_size > 1 and self.rank > 0
        self.event.wait()
        n = int.from_bytes(self.shm.buf[0:4], "little")
        method_name, *args = pickle.loads(self.shm.buf[4:n+4])
        self.event.clear()
        return method_name, args

    def write_shm(self, method_name, *args):
        assert self.world_size > 1 and self.rank == 0
        data = pickle.dumps([method_name, *args])
        n = len(data)
        self.shm.buf[0:4] = n.to_bytes(4, "little")
        self.shm.buf[4:n+4] = data
        for event in self.event:
            event.set()

    def call(self, method_name, *args):
        if self.world_size > 1 and self.rank == 0:
            self.write_shm(method_name, *args)
        method = getattr(self, method_name, None)
        return method(*args)
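    # Illustration (not executed): the shared-memory RPC above uses a tiny
    # length-prefixed pickle format. For a rank-0 call such as
    # call("run", seqs, True), the buffer layout is roughly:
    #
    #   data = pickle.dumps(["run", seqs, True])
    #   shm.buf[0:4]           -> len(data) as a 4-byte little-endian integer
    #   shm.buf[4:4+len(data)] -> the pickled [method_name, *args] list
    #
    # Worker ranks block on their Event, read the length, unpickle the payload,
    # clear the event, and invoke the named method locally, so every rank
    # executes the same call with the same arguments.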
    def warmup_model(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        # Use a reasonable warmup length instead of max_model_len.
        # Warmup only needs to trigger CUDA kernel JIT compilation;
        # using 2 blocks is sufficient and avoids a huge memory allocation.
        warmup_len = min(self.block_size * 2, self.config.max_model_len)
        warmup_len = max(warmup_len, 128)    # at least 128 tokens
        num_seqs = min(self.config.max_num_batched_tokens // warmup_len, self.config.max_num_seqs, 4)
        num_seqs = max(num_seqs, 1)
        seqs = [Sequence([0] * warmup_len) for _ in range(num_seqs)]
        self.run(seqs, True)
        torch.cuda.empty_cache()

    def allocate_kv_cache(self):
        config = self.config
        hf_config = config.hf_config
        free, total = torch.cuda.mem_get_info()
        used = total - free
        peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
        current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
        num_kv_heads = hf_config.num_key_value_heads // self.world_size
        head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize

        # Calculate max GPU blocks based on available memory
        max_gpu_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
        assert max_gpu_blocks > 0

        # Determine final GPU blocks: user-specified or auto (max available)
        if config.num_gpu_blocks > 0:
            num_gpu_blocks = min(config.num_gpu_blocks, max_gpu_blocks)
        else:
            num_gpu_blocks = max_gpu_blocks

        if config.enable_cpu_offload:
            # Three-region design: CPU is primary storage, GPU is working buffer.
            # CPU blocks = all blocks needed to support max_model_len (stores complete KV for one max sequence)
            # GPU blocks = three-region working buffer (user-specified or auto)
            num_cpu_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
            config.num_gpu_kvcache_blocks = num_gpu_blocks
            config.num_cpu_kvcache_blocks = num_cpu_blocks
            # For backward compatibility
            config.num_kvcache_blocks = num_gpu_blocks + num_cpu_blocks
        else:
            config.num_kvcache_blocks = num_gpu_blocks
            config.num_gpu_kvcache_blocks = num_gpu_blocks
            config.num_cpu_kvcache_blocks = 0

        # Create KV cache manager using factory
        self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)

        # Allocate cache through manager
        self.kvcache_manager.allocate_cache(
            num_layers=hf_config.num_hidden_layers,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            dtype=hf_config.torch_dtype,
        )

        # Log KV cache allocation info with detailed per-token breakdown
        gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
        cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
        total_memory_mb = gpu_memory_mb + cpu_memory_mb

        # Calculate per-token KV cache usage:
        # KV per token = 2 (K+V) * num_layers * kv_heads * head_dim * dtype_size
        dtype_size = 2 if hf_config.torch_dtype in [torch.float16, torch.bfloat16] else 4
        per_token_kv_bytes = 2 * hf_config.num_hidden_layers * num_kv_heads * head_dim * dtype_size
        per_token_kv_kb = per_token_kv_bytes / 1024
        logger.info(
            f"KV Cache per-token: {per_token_kv_kb:.2f}KB "
            f"(2 * {hf_config.num_hidden_layers}layers * {num_kv_heads}kv_heads * {head_dim}head_dim * {dtype_size}bytes)"
        )
        logger.info(
            f"KV Cache per-block: {block_bytes / (1024**2):.2f}MB "
            f"({per_token_kv_kb:.2f}KB * {self.block_size}tokens)"
        )

        if config.enable_cpu_offload:
            compute_size = config.num_gpu_kvcache_blocks // 2
            tokens_per_chunk = compute_size * self.block_size
            logger.info(
                f"KV Cache allocated (Chunked Offload mode): "
                f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
                f"CPU={config.num_cpu_kvcache_blocks} blocks ({cpu_memory_mb:.1f}MB), "
                f"Total={total_memory_mb:.1f}MB"
            )
            logger.info(
                f"Chunked Offload config: compute_size={compute_size} blocks, "
                f"tokens_per_chunk={tokens_per_chunk}, "
                f"block_size={self.block_size}"
            )
        else:
            logger.info(
                f"KV Cache allocated: "
                f"GPU={config.num_gpu_kvcache_blocks} blocks ({gpu_memory_mb:.1f}MB), "
                f"block_size={self.block_size}"
            )

        # Bind layer caches to attention modules and set layer_id
        layer_id = 0
        for module in self.model.modules():
            if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
                k_cache, v_cache = self.kvcache_manager.get_layer_cache(layer_id)
                module.k_cache = k_cache
                module.v_cache = v_cache
                # Set layer_id for chunked prefill support
                if hasattr(module, "layer_id"):
                    module.layer_id = layer_id
                layer_id += 1
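    # Worked example of the sizing arithmetic above (illustrative only; assumes a
    # hypothetical config with 28 layers, 8 KV heads per rank, head_dim 128,
    # bf16 weights, and block_size 256):
    #
    #   per_token_kv_bytes = 2 * 28 * 8 * 128 * 2 = 114,688 B   (~112 KB per token)
    #   block_bytes        = 114,688 * 256        = 29,360,128 B (~28 MB per block)
    #
    # max_gpu_blocks then follows from the free-memory budget, e.g. a 10 GiB
    # budget yields 10 * 1024**3 // 29,360,128 = 365 blocks (~93k cached tokens).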
    def prepare_block_tables(self, seqs: list[Sequence]):
        max_len = max(len(seq.block_table) for seq in seqs)
        block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
        block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        return block_tables

    def prepare_prefill(self, seqs: list[Sequence], chunk_info: list[tuple] = None):
        """
        Prepare inputs for prefill.

        Args:
            seqs: List of sequences to prefill
            chunk_info: Optional chunked prefill info from get_gpu_block_tables_partial().
                If provided, only process blocks in the chunk.
                Format: [(gpu_block_ids, start_block_idx, end_block_idx), ...]
        """
        # Check if any sequence has blocks (not warmup)
        has_blocks = any(seq.block_table for seq in seqs)

        gpu_block_tables = None
        if has_blocks and hasattr(self, 'kvcache_manager'):
            if chunk_info is None:
                # Standard prefill - try to get all blocks.
                # This may fail if GPU doesn't have enough capacity.
                self.kvcache_manager.prepare_for_attention(seqs, is_prefill=True)
                gpu_block_tables = self.kvcache_manager.get_gpu_block_tables(seqs)
            else:
                # Chunked prefill - use provided chunk info
                gpu_block_tables = [info[0] for info in chunk_info]

        input_ids = []
        positions = []
        cu_seqlens_q = [0]
        cu_seqlens_k = [0]
        max_seqlen_q = 0
        max_seqlen_k = 0
        slot_mapping = []
        block_tables = None

        for seq_idx, seq in enumerate(seqs):
            if chunk_info is not None:
                # Chunked prefill: only process blocks in the chunk
                gpu_blocks, start_block_idx, end_block_idx = chunk_info[seq_idx]
                if not gpu_blocks:
                    continue

                # Calculate token range for this chunk
                start_token = start_block_idx * self.block_size
                end_token = min(end_block_idx * self.block_size, len(seq))
                if end_block_idx == seq.num_blocks:
                    # Last chunk includes partial last block
                    end_token = len(seq)

                # Input tokens for this chunk
                chunk_tokens = seq[start_token:end_token]
                input_ids.extend(chunk_tokens)
                positions.extend(list(range(start_token, end_token)))

                seqlen_q = end_token - start_token
                seqlen_k = end_token    # Context includes all tokens up to this point
                cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
                cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
                max_seqlen_q = max(seqlen_q, max_seqlen_q)
                max_seqlen_k = max(seqlen_k, max_seqlen_k)

                # Slot mapping for blocks in this chunk
                for i, gpu_block_id in enumerate(gpu_blocks):
                    block_idx = start_block_idx + i
                    start = gpu_block_id * self.block_size
                    if block_idx != seq.num_blocks - 1:
                        end = start + self.block_size
                    else:
                        end = start + seq.last_block_num_tokens
                    slot_mapping.extend(list(range(start, end)))
            else:
                # Standard prefill
                seqlen = len(seq)
                input_ids.extend(seq[seq.num_cached_tokens:])
                positions.extend(list(range(seq.num_cached_tokens, seqlen)))
                seqlen_q = seqlen - seq.num_cached_tokens
                seqlen_k = seqlen
                cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
                cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
                max_seqlen_q = max(seqlen_q, max_seqlen_q)
                max_seqlen_k = max(seqlen_k, max_seqlen_k)

                if not seq.block_table:    # warmup
                    continue

                # Use GPU physical block IDs for slot mapping
                gpu_blocks = gpu_block_tables[seq_idx]
                for i in range(seq.num_cached_blocks, seq.num_blocks):
                    start = gpu_blocks[i] * self.block_size
                    if i != seq.num_blocks - 1:
                        end = start + self.block_size
                    else:
                        end = start + seq.last_block_num_tokens
                    slot_mapping.extend(list(range(start, end)))

        if cu_seqlens_k[-1] > cu_seqlens_q[-1] and gpu_block_tables:    # prefix cache
            block_tables = self._prepare_gpu_block_tables(gpu_block_tables)

        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
        return input_ids, positions
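    # Illustration of the slot mapping built above (numbers are made up): with
    # block_size = 4 and a sequence whose GPU physical blocks are [7, 2] holding
    # 6 tokens, the per-token cache slots are
    #
    #   block 7 -> slots 28, 29, 30, 31   (7 * 4 .. 7 * 4 + 3)
    #   block 2 -> slots  8,  9           (last block holds only 2 tokens)
    #
    # so slot_mapping = [28, 29, 30, 31, 8, 9]; the attention kernels use it to
    # scatter each new K/V vector into the paged cache.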
    def prepare_decode(self, seqs: list[Sequence]):
        # Prepare KV cache (updates gather_indices for hybrid manager)
        if hasattr(self, 'kvcache_manager'):
            self.kvcache_manager.prepare_for_attention(seqs, is_prefill=False)
            # Get GPU physical block tables
            gpu_block_tables = self.kvcache_manager.get_gpu_block_tables(seqs)
        else:
            gpu_block_tables = [list(seq.block_table) for seq in seqs]

        input_ids = []
        positions = []
        slot_mapping = []
        context_lens = []
        for seq_idx, seq in enumerate(seqs):
            input_ids.append(seq.last_token)
            positions.append(len(seq) - 1)
            context_lens.append(len(seq))
            # Use GPU physical block ID for slot mapping
            gpu_blocks = gpu_block_tables[seq_idx]
            slot_mapping.append(gpu_blocks[-1] * self.block_size + seq.last_block_num_tokens - 1)

        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        # Use GPU physical block tables for attention
        block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
        set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
        return input_ids, positions

    def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]):
        """Prepare block tables tensor from GPU physical block IDs."""
        max_len = max(len(bt) for bt in gpu_block_tables)
        padded = [bt + [-1] * (max_len - len(bt)) for bt in gpu_block_tables]
        return torch.tensor(padded, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

    def prepare_sample(self, seqs: list[Sequence]):
        temperatures = []
        for seq in seqs:
            temperatures.append(seq.temperature)
        temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
        return temperatures

    @torch.inference_mode()
    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
        context = get_context()
        # Use eager mode for: prefill, enforce_eager, large batch, or chunked attention.
        # Chunked attention requires dynamic KV loading that can't be captured in a CUDA graph.
        use_eager = is_prefill or self.enforce_eager or input_ids.size(0) > 512 or context.is_chunked_prefill
        if use_eager:
            return self.model.compute_logits(self.model(input_ids, positions))
        else:
            bs = input_ids.size(0)
            graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
            graph_vars = self.graph_vars
            graph_vars["input_ids"][:bs] = input_ids
            graph_vars["positions"][:bs] = positions
            graph_vars["slot_mapping"].fill_(-1)
            graph_vars["slot_mapping"][:bs] = context.slot_mapping
            graph_vars["context_lens"].zero_()
            graph_vars["context_lens"][:bs] = context.context_lens
            graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
            graph.replay()
            return self.model.compute_logits(graph_vars["outputs"][:bs])
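    # Note on the CUDA-graph path above (illustrative numbers): graphs are
    # captured only for the batch sizes in self.graph_bs. A decode batch of, say,
    # 10 sequences replays the graph captured for the next bucket that fits
    # (16); only the first 10 rows of the shared static buffers carry real data,
    # unused slot_mapping entries are set to -1 and context_lens to 0, and the
    # logits are computed from outputs[:10] after replay.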
    def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
        # Check if Chunked Offload mode should be used (all blocks on CPU)
        if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'):
            use_chunked_offload = self._should_use_chunked_offload(seqs, is_prefill)
            if use_chunked_offload:
                if is_prefill:
                    return self.run_chunked_offload_prefill(seqs)
                else:
                    return self.run_chunked_offload_decode(seqs)

        # Check if chunked prefill is needed (legacy path)
        if is_prefill and hasattr(self, 'kvcache_manager'):
            needs_chunked = any(
                hasattr(self.kvcache_manager, 'needs_chunked_prefill')
                and self.kvcache_manager.needs_chunked_prefill(seq)
                for seq in seqs if seq.block_table
            )
            if needs_chunked:
                return self.run_chunked_prefill(seqs)

        # Check if chunked decode is needed (legacy path)
        if not is_prefill and hasattr(self, 'kvcache_manager'):
            needs_chunked = any(
                hasattr(self.kvcache_manager, 'needs_chunked_decode')
                and self.kvcache_manager.needs_chunked_decode(seq)
                for seq in seqs if seq.block_table
            )
            if needs_chunked:
                return self.run_chunked_decode(seqs)

        input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        logits = self.run_model(input_ids, positions, is_prefill)
        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
        reset_context()
        return token_ids

    def _should_use_chunked_offload(self, seqs: list[Sequence], is_prefill: bool) -> bool:
        """
        Check if three-region mode should be used.

        Use three-region when:
        - CPU offload is enabled
        - There are blocks on CPU (either allocated there or offloaded)
        - Sequence exceeds GPU Compute region capacity
        """
        if not hasattr(self.kvcache_manager, 'offload_engine'):
            return False

        for seq in seqs:
            if not seq.block_table:
                continue    # Skip warmup sequences

            # Check if any blocks are on CPU
            cpu_blocks, _ = self.kvcache_manager.get_all_cpu_blocks(seq)
            if cpu_blocks:
                # Has CPU blocks - use three-region
                return True

            # Check if sequence needs more blocks than GPU Compute region can hold
            compute_size = self.kvcache_manager.offload_engine.num_compute_blocks
            if seq.num_blocks > compute_size:
                # Needs chunked processing
                return True

        return False
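    # Example of the dispatch above (hypothetical sizes): with an offload engine
    # whose compute region holds 8 blocks and block_size = 256, a 12-block prompt
    # (~3k tokens) exceeds the compute region, so run() takes the chunked-offload
    # path even before any block has been moved to CPU; an 8-block prompt with no
    # CPU-resident blocks stays on the ordinary prefill/decode path.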
    def run_chunked_prefill(self, seqs: list[Sequence]) -> list[int]:
        """
        Run prefill in chunks when sequences exceed GPU capacity.

        For each chunk:
        1. Process tokens through model forward pass
        2. At each attention layer:
           - Load previous KV from CPU (handled by attention layer)
           - Compute attention with online softmax merging
           - Store current KV to GPU cache
        3. After chunk completes, offload KV to CPU
        4. Load next chunk's blocks to GPU
        """
        import sys

        # Currently only supporting single sequence for chunked prefill
        assert len(seqs) == 1, "Chunked prefill only supports single sequence"
        seq = seqs[0]
        total_blocks = seq.num_blocks

        print(f"[Chunked Prefill] Starting: {total_blocks} total blocks, "
              f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)

        chunk_num = 0
        logits = None

        while True:
            # Get chunk info (which blocks are on GPU and not yet prefilled)
            chunk_info = self.kvcache_manager.get_gpu_block_tables_partial(seqs)
            gpu_blocks, start_block_idx, end_block_idx = chunk_info[0]

            if not gpu_blocks:
                # No more blocks to process
                break

            chunk_num += 1
            chunk_tokens = (end_block_idx - start_block_idx) * self.block_size
            if end_block_idx == seq.num_blocks:
                # Last block may be partial
                chunk_tokens = len(seq) - start_block_idx * self.block_size

            print(f"[Chunked Prefill] Chunk {chunk_num}: blocks {start_block_idx}-{end_block_idx-1}, "
                  f"~{chunk_tokens} tokens", file=sys.stderr)

            # Prepare inputs for this chunk
            input_ids, positions = self._prepare_chunked_prefill(seq, gpu_blocks, start_block_idx, end_block_idx)

            if input_ids.numel() == 0:
                print(f"[Chunked Prefill] No input tokens, breaking", file=sys.stderr)
                break

            print(f"[Chunked Prefill] Running model with {input_ids.numel()} tokens...", file=sys.stderr)

            # Run model forward pass
            logits = self.run_model(input_ids, positions, is_prefill=True)
            reset_context()

            print(f"[Chunked Prefill] Model forward complete", file=sys.stderr)

            # Mark current chunk as prefilled and offload to CPU
            self.kvcache_manager.complete_prefill_chunk(seq)

            # Check if this is the last chunk / if more chunks are needed
            if not self.kvcache_manager.needs_chunked_prefill(seq):
                print(f"[Chunked Prefill] All chunks done, sampling", file=sys.stderr)
                break

            print(f"[Chunked Prefill] Chunk transfer complete, loading next...", file=sys.stderr)

        # Sample from the last chunk's logits
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        if logits is not None:
            token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
        else:
            token_ids = [0] if self.rank == 0 else None
        return token_ids
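    # Background for the "online softmax merging" mentioned above (a sketch; the
    # actual merge lives in the attention layer, not in this file). Two partial
    # attention results over disjoint KV chunks, each carrying its own row-max m
    # and softmax normalizer l, can be combined exactly as:
    #
    #   m = max(m1, m2)
    #   l = l1 * exp(m1 - m) + l2 * exp(m2 - m)
    #   o = (l1 * exp(m1 - m) * o1 + l2 * exp(m2 - m) * o2) / l
    #
    # Applying this pairwise across chunks reproduces full-context attention
    # without ever holding the whole KV cache on the GPU at once.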
    def run_chunked_decode(self, seqs: list[Sequence]) -> list[int]:
        """
        Run decode with chunked attention when sequence exceeds GPU capacity.

        For decode, we need attention over ALL previous tokens. With CPU offload,
        we load KV chunks and compute attention incrementally per-layer.

        Flow:
        1. Ensure last block is on GPU (for writing new KV token)
        2. Run model forward - each attention layer:
           a. Compute attention on GPU blocks
           b. Load CPU blocks in chunks, compute + merge
        3. Sample from output
        """
        # Currently only supporting single sequence for chunked decode
        assert len(seqs) == 1, "Chunked decode only supports single sequence"
        seq = seqs[0]

        # Prepare inputs
        input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)

        # Ensure last block is on GPU for writing new KV token
        last_gpu_slot = self.kvcache_manager.ensure_last_block_on_gpu(seq)
        slot = last_gpu_slot * self.block_size + seq.last_block_num_tokens - 1
        slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

        # Set up context for chunked decode
        set_context(
            is_prefill=False,               # Decode mode
            slot_mapping=slot_mapping,
            context_lens=context_len,
            is_chunked_prefill=True,        # Use chunked attention path
            kvcache_manager=self.kvcache_manager,
            chunked_seq=seq,
        )

        # Run model forward pass.
        # Each attention layer will handle chunked KV loading internally.
        logits = self.run_model(input_ids, positions, is_prefill=False)
        reset_context()

        # Sample
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
        return token_ids

    def _prepare_chunked_prefill(
        self,
        seq: Sequence,
        gpu_blocks: list[int],
        start_block_idx: int,
        end_block_idx: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare inputs for a single chunk in chunked prefill.

        Sets up context with is_chunked_prefill=True so attention layers
        know to load previous KV from CPU.
        """
        # Calculate token range for this chunk
        start_token = start_block_idx * self.block_size
        end_token = min(end_block_idx * self.block_size, len(seq))

        # Input tokens for this chunk
        input_ids = seq[start_token:end_token]
        positions = list(range(start_token, end_token))

        # Slot mapping for storing KV cache
        slot_mapping = []
        for i, gpu_block_id in enumerate(gpu_blocks):
            block_idx = start_block_idx + i
            start = gpu_block_id * self.block_size
            if block_idx != seq.num_blocks - 1:
                end = start + self.block_size
            else:
                end = start + seq.last_block_num_tokens
            slot_mapping.extend(list(range(start, end)))

        # Trim slot_mapping to match actual token count
        actual_tokens = end_token - start_token
        slot_mapping = slot_mapping[:actual_tokens]

        # Convert to tensors
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

        # Set up context for chunked prefill
        seqlen = actual_tokens
        cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

        set_context(
            is_prefill=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=seqlen,
            max_seqlen_k=seqlen,
            slot_mapping=slot_mapping,
            is_chunked_prefill=True,
            kvcache_manager=self.kvcache_manager,   # Pass manager for loading previous KV
            chunked_seq=seq,                        # Pass sequence for loading previous KV
        )

        return input_ids, positions
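    # Illustration of the chunk bookkeeping above (made-up numbers): with
    # block_size = 256, a chunk with start_block_idx = 4 and end_block_idx = 7
    # of a 2000-token prompt spans tokens 1024..1792
    # (start_block_idx * block_size up to min(end_block_idx * block_size, len(seq))).
    # If the chunk reaches the final block, end_token is clamped to len(seq) and
    # slot_mapping is trimmed to the actual token count so the partial last
    # block is written correctly.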
    def run_chunked_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
        """
        Run prefill with unified ring buffer (CPU is primary storage).

        Flow:
        1. All blocks are allocated to CPU (primary storage)
        2. Each chunk writes KV to ring buffer slot[chunk_idx % N]
        3. After each chunk, offload from ring buffer slot to CPU
        4. All N-1 other slots are used to load previous chunks for attention
        """
        import sys

        assert len(seqs) == 1, "Ring buffer prefill only supports single sequence"
        seq = seqs[0]
        offload_engine = self.kvcache_manager.offload_engine

        # Each chunk uses 1 ring buffer slot = 1 block
        tokens_per_chunk = self.block_size
        total_tokens = len(seq)
        num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk

        print(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
              f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
              f"total_chunks={num_chunks}", file=sys.stderr)

        chunk_idx = 0
        logits = None
        processed_tokens = 0

        # Get CPU block table for offload targets
        cpu_block_ids, logical_ids = self.kvcache_manager.get_all_cpu_blocks(seq)

        while processed_tokens < total_tokens:
            chunk_start = processed_tokens
            chunk_end = min(processed_tokens + tokens_per_chunk, total_tokens)

            # Get ring buffer slot for this chunk
            write_slot = offload_engine.get_write_slot_for_prefill(chunk_idx)
            # CPU block index for this chunk
            block_idx = chunk_idx

            print(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
                  f"write_slot={write_slot}", file=sys.stderr)

            # Prepare inputs
            input_ids, positions = self._prepare_chunked_offload_chunk(
                seq, chunk_start, chunk_end, write_slot, block_idx, chunk_idx
            )

            if input_ids.numel() == 0:
                break

            # Run model forward
            logits = self.run_model(input_ids, positions, is_prefill=True)
            reset_context()

            # Mark block as prefilled
            if block_idx < len(seq.block_table):
                logical_id = seq.block_table[block_idx]
                self.kvcache_manager.prefilled_blocks.add(logical_id)

            # Offload this chunk's ring buffer slot to CPU (async)
            if block_idx < len(cpu_block_ids):
                cpu_block_id = cpu_block_ids[block_idx]

                # Call sparse policy hook before offload (to capture metadata)
                sparse_policy = self.kvcache_manager.sparse_policy
                if sparse_policy is not None:
                    num_tokens = chunk_end - chunk_start
                    for layer_id in range(offload_engine.num_layers):
                        k_cache = offload_engine.k_cache_gpu[layer_id, write_slot, :num_tokens]
                        sparse_policy.on_block_offloaded(
                            cpu_block_id=cpu_block_id,
                            layer_id=layer_id,
                            k_cache=k_cache,
                            num_valid_tokens=num_tokens,
                        )

                offload_engine.offload_slot_to_cpu(write_slot, cpu_block_id)
                # Wait for offload to complete before next chunk
                # (slot will be reused after N chunks)
                offload_engine.wait_slot_offload(write_slot)

            processed_tokens = chunk_end
            chunk_idx += 1

        # Wait for all offloads to complete
        offload_engine.wait_all_offload_done()

        print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)

        # Sample from last logits
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        if logits is not None:
            token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
        else:
            token_ids = [0] if self.rank == 0 else None
        return token_ids
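    # Ring-buffer intuition for the prefill loop above (illustrative): with
    # N = 4 ring slots, chunks 0, 1, 2, 3, 4, 5, ... write into slots
    # 0, 1, 2, 3, 0, 1, ... (slot[chunk_idx % N]). The slot being written is
    # offloaded to its CPU block and waited on before it is reused, while the
    # other N-1 slots stage previously offloaded chunks back onto the GPU so the
    # attention layers can attend over the full prefix.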
    def _prepare_chunked_offload_chunk(
        self,
        seq: Sequence,
        chunk_start: int,
        chunk_end: int,
        write_slot: int,
        block_idx: int,
        chunk_idx: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Prepare inputs for a chunked offload prefill chunk (ring buffer design)."""
        # Input tokens for this chunk
        input_ids = seq[chunk_start:chunk_end]
        positions = list(range(chunk_start, chunk_end))

        # Create slot mapping pointing to the single write_slot
        slot_mapping = []
        for pos in range(chunk_start, chunk_end):
            pos_in_block = pos % self.block_size
            slot = write_slot * self.block_size + pos_in_block
            slot_mapping.append(slot)

        # Convert to tensors
        num_tokens = chunk_end - chunk_start
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

        # Set up context for chunked prefill
        seqlen = num_tokens
        cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

        set_context(
            is_prefill=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=seqlen,
            max_seqlen_k=seqlen,
            slot_mapping=slot_mapping,
            is_chunked_prefill=True,
            kvcache_manager=self.kvcache_manager,
            chunked_seq=seq,
            current_chunk_idx=chunk_idx,    # Pass chunk index for ring buffer pipeline
        )

        return input_ids, positions

    def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
        """
        Run decode with ring buffer (CPU is primary storage).

        All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
        Other slots (slots[1:]) are used to load previous KV chunks via pipeline.
        New token's KV is written to decode_slot then offloaded to CPU only when
        the block is full.

        Key: decode_slot is dedicated to writing new KV, never used for loading.

        Optimization: Batch offloads - only offload when block is full,
        attend to all accumulated tokens.
        """
        assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
        seq = seqs[0]
        offload_engine = self.kvcache_manager.offload_engine

        # Prepare inputs
        input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)

        # Use Decode region (slot 0) to write new KV
        decode_slot = offload_engine.decode_slot    # = 0
        pos_in_block = (len(seq) - 1) % self.block_size
        slot = decode_slot * self.block_size + pos_in_block
        slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

        # Get decode start position for accumulated token tracking
        decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)

        # Set up context for chunked decode
        set_context(
            is_prefill=False,
            slot_mapping=slot_mapping,
            context_lens=context_len,
            is_chunked_prefill=True,    # Use chunked attention path
            kvcache_manager=self.kvcache_manager,
            chunked_seq=seq,
            decode_pos_in_block=pos_in_block,
            decode_start_pos_in_block=decode_start_pos,
        )

        # Run model forward pass
        logits = self.run_model(input_ids, positions, is_prefill=False)
        reset_context()

        # Only offload when block is full (pos_in_block == block_size - 1).
        # This avoids unnecessary offloading on every decode step.
        if pos_in_block == self.block_size - 1:
            last_cpu_block = self.kvcache_manager.get_last_cpu_block(seq)
            if last_cpu_block >= 0:
                offload_engine.offload_decode_slot(last_cpu_block)
                offload_engine.wait_all_offload_done()
                # Reset decode start position for next block
                self.kvcache_manager.reset_decode_start_pos(seq)

        # Sample
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
        return token_ids
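    # Decode-slot cadence (illustrative): with block_size = 256, the decode slot
    # accumulates newly generated tokens in place, and the offload + wait +
    # reset path above fires only when pos_in_block reaches 255, i.e. once per
    # 256 generated tokens rather than on every decode step.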
    @torch.inference_mode()
    def capture_cudagraph(self):
        config = self.config
        hf_config = config.hf_config
        max_bs = min(self.config.max_num_seqs, 512)
        max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
        input_ids = torch.zeros(max_bs, dtype=torch.int64)
        positions = torch.zeros(max_bs, dtype=torch.int64)
        slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
        context_lens = torch.zeros(max_bs, dtype=torch.int32)
        block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
        outputs = torch.zeros(max_bs, hf_config.hidden_size)
        self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
        self.graphs = {}
        self.graph_pool = None

        for bs in reversed(self.graph_bs):
            graph = torch.cuda.CUDAGraph()
            set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # warmup
            with torch.cuda.graph(graph, self.graph_pool):
                outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # capture
            if self.graph_pool is None:
                self.graph_pool = graph.pool()
            self.graphs[bs] = graph
            torch.cuda.synchronize()
            reset_context()

        self.graph_vars = dict(
            input_ids=input_ids,
            positions=positions,
            slot_mapping=slot_mapping,
            context_lens=context_lens,
            block_tables=block_tables,
            outputs=outputs,
        )
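    # Notes on the capture loop above: batch sizes are captured largest-first,
    # so the first (largest) capture creates the memory pool that the remaining
    # graphs share via torch.cuda.graph(graph, self.graph_pool), and every graph
    # reads from and writes to the same static buffers collected in graph_vars.
    # With max_num_seqs >= 512 the captured bucket sizes are 1, 2, 4, 8 and then
    # every multiple of 16 up to 512; run_model() replays the smallest captured
    # bucket that fits the live decode batch.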