From 0b6f19242d7e6e8c72d9456c838ccefc78695169 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Wed, 10 Dec 2025 03:47:37 +0800 Subject: [PATCH] [feat] Added chunked prefill and kvcache offload mechenism. --- bench.py | 10 +- bench_offload.py | 64 ++ bench_vllm.py | 10 +- nanovllm/config.py | 10 + nanovllm/engine/llm_engine.py | 2 +- nanovllm/engine/model_runner.py | 390 +++++++++- nanovllm/engine/scheduler.py | 21 +- nanovllm/kvcache/__init__.py | 74 ++ nanovllm/kvcache/base_manager.py | 260 +++++++ nanovllm/kvcache/chunked_attention.py | 555 ++++++++++++++ nanovllm/kvcache/gpu_manager.py | 262 +++++++ nanovllm/kvcache/hybrid_manager.py | 906 +++++++++++++++++++++++ nanovllm/kvcache/kernels.py | 190 +++++ nanovllm/kvcache/offload_engine.py | 400 ++++++++++ nanovllm/kvcache/policies/__init__.py | 51 ++ nanovllm/kvcache/policies/base_policy.py | 156 ++++ nanovllm/kvcache/policies/fifo_policy.py | 101 +++ nanovllm/kvcache/policies/lru_policy.py | 93 +++ nanovllm/layers/attention.py | 159 +++- nanovllm/utils/context.py | 53 +- tests/__init__.py | 1 + tests/test_kernels.py | 169 +++++ tests/test_kvcache_manager.py | 175 +++++ tests/test_offload_engine.py | 196 +++++ tests/test_policies.py | 167 +++++ 25 files changed, 4414 insertions(+), 61 deletions(-) create mode 100644 bench_offload.py create mode 100644 nanovllm/kvcache/__init__.py create mode 100644 nanovllm/kvcache/base_manager.py create mode 100644 nanovllm/kvcache/chunked_attention.py create mode 100644 nanovllm/kvcache/gpu_manager.py create mode 100644 nanovllm/kvcache/hybrid_manager.py create mode 100644 nanovllm/kvcache/kernels.py create mode 100644 nanovllm/kvcache/offload_engine.py create mode 100644 nanovllm/kvcache/policies/__init__.py create mode 100644 nanovllm/kvcache/policies/base_policy.py create mode 100644 nanovllm/kvcache/policies/fifo_policy.py create mode 100644 nanovllm/kvcache/policies/lru_policy.py create mode 100644 tests/__init__.py create mode 100644 tests/test_kernels.py create mode 100644 tests/test_kvcache_manager.py create mode 100644 tests/test_offload_engine.py create mode 100644 tests/test_policies.py diff --git a/bench.py b/bench.py index 535251f..514721b 100644 --- a/bench.py +++ b/bench.py @@ -44,16 +44,16 @@ def main(): print("Prefill Benchmark") print("=" * 60) bench_prefill(llm, num_seqs=1, input_len=1024) - bench_prefill(llm, num_seqs=1, input_len=2048) - bench_prefill(llm, num_seqs=1, input_len=4095) - bench_prefill(llm, num_seqs=16, input_len=1024) - bench_prefill(llm, num_seqs=64, input_len=1024) + # bench_prefill(llm, num_seqs=1, input_len=2048) + # bench_prefill(llm, num_seqs=1, input_len=4095) + # bench_prefill(llm, num_seqs=16, input_len=1024) + # bench_prefill(llm, num_seqs=64, input_len=1024) print("=" * 60) print("Decode Benchmark") print("=" * 60) bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=1024) - bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024) + # bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024) if __name__ == "__main__": diff --git a/bench_offload.py b/bench_offload.py new file mode 100644 index 0000000..17c138e --- /dev/null +++ b/bench_offload.py @@ -0,0 +1,64 @@ +import os +import time +from random import randint, seed +from nanovllm import LLM, SamplingParams + + +def bench_decode(llm, num_seqs, max_input_len, max_output_len): + """Benchmark decode performance""" + seed(0) + prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)] + sampling_params = 
[SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)] + + t = time.time() + llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) + t = time.time() - t + total_output_tokens = sum(sp.max_tokens for sp in sampling_params) + throughput = total_output_tokens / t + print(f"[Decode] Output: {total_output_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s") + + +def bench_prefill(llm, num_seqs, input_len): + """Benchmark prefill performance""" + seed(0) + prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)] + sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1) + + t = time.time() + llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) + t = time.time() - t + total_input_tokens = num_seqs * input_len + throughput = total_input_tokens / t + print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s") + + +def main(): + path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") + llm = LLM( + path, + enforce_eager=True, + max_model_len=128 * 1024, + max_num_batched_tokens=128 * 1024, + enable_cpu_offload=True, + cpu_memory_gb=32.0, + ) + + # Warmup + llm.generate(["Benchmark: "], SamplingParams()) + + print("=" * 60) + print("Prefill Benchmark (CPU Offload)") + print("=" * 60) + bench_prefill(llm, num_seqs=1, input_len=64*1024) + # bench_prefill(llm, num_seqs=1, input_len=16384) + # bench_prefill(llm, num_seqs=1, input_len=32000) + + print("=" * 60) + print("Decode Benchmark (CPU Offload)") + print("=" * 60) + bench_decode(llm, num_seqs=1, max_input_len=64*1024, max_output_len=256) + # bench_decode(llm, num_seqs=1, max_input_len=16384, max_output_len=256) + + +if __name__ == "__main__": + main() diff --git a/bench_vllm.py b/bench_vllm.py index ff5789a..8497f44 100644 --- a/bench_vllm.py +++ b/bench_vllm.py @@ -47,16 +47,16 @@ def main(): print("Prefill Benchmark") print("=" * 60) bench_prefill(llm, num_seqs=1, input_len=1024) - bench_prefill(llm, num_seqs=1, input_len=2048) - bench_prefill(llm, num_seqs=1, input_len=4095) - bench_prefill(llm, num_seqs=16, input_len=1024) - bench_prefill(llm, num_seqs=64, input_len=1024) + # bench_prefill(llm, num_seqs=1, input_len=2048) + # bench_prefill(llm, num_seqs=1, input_len=4095) + # bench_prefill(llm, num_seqs=16, input_len=1024) + # bench_prefill(llm, num_seqs=64, input_len=1024) print("=" * 60) print("Decode Benchmark") print("=" * 60) bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=1024) - bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024) + # bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024) if __name__ == "__main__": diff --git a/nanovllm/config.py b/nanovllm/config.py index 959ffb3..9e01048 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -17,6 +17,16 @@ class Config: kvcache_block_size: int = 256 num_kvcache_blocks: int = -1 + # CPU Offload configuration + enable_cpu_offload: bool = False + cpu_memory_gb: float = 16.0 # CPU memory limit for KV cache + offload_policy: str = "lru" # "lru", "fifo", or full class path + num_transfer_streams: int = 4 # Number of CUDA streams for async transfers + + # Computed fields for offload (set in __post_init__ or by ModelRunner) + num_gpu_kvcache_blocks: int = -1 + num_cpu_kvcache_blocks: int = -1 + def __post_init__(self): assert os.path.isdir(self.model) assert self.kvcache_block_size % 256 == 0 diff --git 
a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py index 5efcfd1..637b5f6 100644 --- a/nanovllm/engine/llm_engine.py +++ b/nanovllm/engine/llm_engine.py @@ -31,7 +31,7 @@ class LLMEngine: self.model_runner = ModelRunner(config, 0, self.events) self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True) config.eos = self.tokenizer.eos_token_id - self.scheduler = Scheduler(config) + self.scheduler = Scheduler(config, self.model_runner.kvcache_manager) atexit.register(self.exit) def exit(self): diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index f66c38e..38a9fcc 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -10,6 +10,7 @@ from nanovllm.models.qwen3 import Qwen3ForCausalLM from nanovllm.layers.sampler import Sampler from nanovllm.utils.context import set_context, get_context, reset_context from nanovllm.utils.loader import load_model +from nanovllm.kvcache import create_kvcache_manager, KVCacheManager class ModelRunner: @@ -107,14 +108,45 @@ class ModelRunner: num_kv_heads = hf_config.num_key_value_heads // self.world_size head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads) block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize - config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes - assert config.num_kvcache_blocks > 0 - self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim) + + # Calculate GPU block count + num_gpu_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes + assert num_gpu_blocks > 0 + + if config.enable_cpu_offload: + # Calculate CPU blocks based on cpu_memory_gb + cpu_bytes = int(config.cpu_memory_gb * 1024**3) + num_cpu_blocks = cpu_bytes // block_bytes + config.num_gpu_kvcache_blocks = num_gpu_blocks + config.num_cpu_kvcache_blocks = num_cpu_blocks + # For backward compatibility + config.num_kvcache_blocks = num_gpu_blocks + num_cpu_blocks + else: + config.num_kvcache_blocks = num_gpu_blocks + config.num_gpu_kvcache_blocks = num_gpu_blocks + config.num_cpu_kvcache_blocks = 0 + + # Create KV cache manager using factory + self.kvcache_manager: KVCacheManager = create_kvcache_manager(config) + + # Allocate cache through manager + self.kvcache_manager.allocate_cache( + num_layers=hf_config.num_hidden_layers, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + dtype=hf_config.torch_dtype, + ) + + # Bind layer caches to attention modules and set layer_id layer_id = 0 for module in self.model.modules(): if hasattr(module, "k_cache") and hasattr(module, "v_cache"): - module.k_cache = self.kv_cache[0, layer_id] - module.v_cache = self.kv_cache[1, layer_id] + k_cache, v_cache = self.kvcache_manager.get_layer_cache(layer_id) + module.k_cache = k_cache + module.v_cache = v_cache + # Set layer_id for chunked prefill support + if hasattr(module, "layer_id"): + module.layer_id = layer_id layer_id += 1 def prepare_block_tables(self, seqs: list[Sequence]): @@ -123,7 +155,30 @@ class ModelRunner: block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) return block_tables - def prepare_prefill(self, seqs: list[Sequence]): + def prepare_prefill(self, seqs: list[Sequence], chunk_info: list[tuple] = None): + """ + Prepare inputs for prefill. 
+ + Args: + seqs: List of sequences to prefill + chunk_info: Optional chunked prefill info from get_gpu_block_tables_partial(). + If provided, only process blocks in the chunk. + Format: [(gpu_block_ids, start_block_idx, end_block_idx), ...] + """ + # Check if any sequence has blocks (not warmup) + has_blocks = any(seq.block_table for seq in seqs) + + gpu_block_tables = None + if has_blocks and hasattr(self, 'kvcache_manager'): + if chunk_info is None: + # Standard prefill - try to get all blocks + # This may fail if GPU doesn't have enough capacity + self.kvcache_manager.prepare_for_attention(seqs, is_prefill=True) + gpu_block_tables = self.kvcache_manager.get_gpu_block_tables(seqs) + else: + # Chunked prefill - use provided chunk info + gpu_block_tables = [info[0] for info in chunk_info] + input_ids = [] positions = [] cu_seqlens_q = [0] @@ -132,27 +187,67 @@ class ModelRunner: max_seqlen_k = 0 slot_mapping = [] block_tables = None - for seq in seqs: - seqlen = len(seq) - input_ids.extend(seq[seq.num_cached_tokens:]) - positions.extend(list(range(seq.num_cached_tokens, seqlen))) - seqlen_q = seqlen - seq.num_cached_tokens - seqlen_k = seqlen - cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) - cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) - max_seqlen_q = max(seqlen_q, max_seqlen_q) - max_seqlen_k = max(seqlen_k, max_seqlen_k) - if not seq.block_table: # warmup - continue - for i in range(seq.num_cached_blocks, seq.num_blocks): - start = seq.block_table[i] * self.block_size - if i != seq.num_blocks - 1: - end = start + self.block_size - else: - end = start + seq.last_block_num_tokens - slot_mapping.extend(list(range(start, end))) - if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache - block_tables = self.prepare_block_tables(seqs) + + for seq_idx, seq in enumerate(seqs): + if chunk_info is not None: + # Chunked prefill: only process blocks in the chunk + gpu_blocks, start_block_idx, end_block_idx = chunk_info[seq_idx] + if not gpu_blocks: + continue + + # Calculate token range for this chunk + start_token = start_block_idx * self.block_size + end_token = min(end_block_idx * self.block_size, len(seq)) + if end_block_idx == seq.num_blocks: + # Last chunk includes partial last block + end_token = len(seq) + + # Input tokens for this chunk + chunk_tokens = seq[start_token:end_token] + input_ids.extend(chunk_tokens) + positions.extend(list(range(start_token, end_token))) + + seqlen_q = end_token - start_token + seqlen_k = end_token # Context includes all tokens up to this point + cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) + cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) + max_seqlen_q = max(seqlen_q, max_seqlen_q) + max_seqlen_k = max(seqlen_k, max_seqlen_k) + + # Slot mapping for blocks in this chunk + for i, gpu_block_id in enumerate(gpu_blocks): + block_idx = start_block_idx + i + start = gpu_block_id * self.block_size + if block_idx != seq.num_blocks - 1: + end = start + self.block_size + else: + end = start + seq.last_block_num_tokens + slot_mapping.extend(list(range(start, end))) + else: + # Standard prefill + seqlen = len(seq) + input_ids.extend(seq[seq.num_cached_tokens:]) + positions.extend(list(range(seq.num_cached_tokens, seqlen))) + seqlen_q = seqlen - seq.num_cached_tokens + seqlen_k = seqlen + cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) + cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) + max_seqlen_q = max(seqlen_q, max_seqlen_q) + max_seqlen_k = max(seqlen_k, max_seqlen_k) + if not seq.block_table: # warmup + continue + # Use GPU physical block IDs for slot 
mapping + gpu_blocks = gpu_block_tables[seq_idx] + for i in range(seq.num_cached_blocks, seq.num_blocks): + start = gpu_blocks[i] * self.block_size + if i != seq.num_blocks - 1: + end = start + self.block_size + else: + end = start + seq.last_block_num_tokens + slot_mapping.extend(list(range(start, end))) + + if cu_seqlens_k[-1] > cu_seqlens_q[-1] and gpu_block_tables: # prefix cache + block_tables = self._prepare_gpu_block_tables(gpu_block_tables) input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) @@ -162,23 +257,40 @@ class ModelRunner: return input_ids, positions def prepare_decode(self, seqs: list[Sequence]): + # Prepare KV cache (updates gather_indices for hybrid manager) + if hasattr(self, 'kvcache_manager'): + self.kvcache_manager.prepare_for_attention(seqs, is_prefill=False) + # Get GPU physical block tables + gpu_block_tables = self.kvcache_manager.get_gpu_block_tables(seqs) + else: + gpu_block_tables = [list(seq.block_table) for seq in seqs] + input_ids = [] positions = [] slot_mapping = [] context_lens = [] - for seq in seqs: + for seq_idx, seq in enumerate(seqs): input_ids.append(seq.last_token) positions.append(len(seq) - 1) context_lens.append(len(seq)) - slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1) + # Use GPU physical block ID for slot mapping + gpu_blocks = gpu_block_tables[seq_idx] + slot_mapping.append(gpu_blocks[-1] * self.block_size + seq.last_block_num_tokens - 1) input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) - block_tables = self.prepare_block_tables(seqs) + # Use GPU physical block tables for attention + block_tables = self._prepare_gpu_block_tables(gpu_block_tables) set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables) return input_ids, positions + def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]): + """Prepare block tables tensor from GPU physical block IDs.""" + max_len = max(len(bt) for bt in gpu_block_tables) + padded = [bt + [-1] * (max_len - len(bt)) for bt in gpu_block_tables] + return torch.tensor(padded, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + def prepare_sample(self, seqs: list[Sequence]): temperatures = [] for seq in seqs: @@ -206,6 +318,26 @@ class ModelRunner: return self.model.compute_logits(graph_vars["outputs"][:bs]) def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]: + # Check if chunked prefill is needed + if is_prefill and hasattr(self, 'kvcache_manager'): + needs_chunked = any( + hasattr(self.kvcache_manager, 'needs_chunked_prefill') and + self.kvcache_manager.needs_chunked_prefill(seq) + for seq in seqs if seq.block_table + ) + if needs_chunked: + return self.run_chunked_prefill(seqs) + + # Check if chunked decode is needed + if not is_prefill and hasattr(self, 'kvcache_manager'): + needs_chunked = any( + hasattr(self.kvcache_manager, 'needs_chunked_decode') and + self.kvcache_manager.needs_chunked_decode(seq) + 
for seq in seqs if seq.block_table + ) + if needs_chunked: + return self.run_chunked_decode(seqs) + input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) temperatures = self.prepare_sample(seqs) if self.rank == 0 else None logits = self.run_model(input_ids, positions, is_prefill) @@ -213,6 +345,204 @@ class ModelRunner: reset_context() return token_ids + def run_chunked_prefill(self, seqs: list[Sequence]) -> list[int]: + """ + Run prefill in chunks when sequences exceed GPU capacity. + + For each chunk: + 1. Process tokens through model forward pass + 2. At each attention layer: + - Load previous KV from CPU (handled by attention layer) + - Compute attention with online softmax merging + - Store current KV to GPU cache + 3. After chunk completes, offload KV to CPU + 4. Load next chunk's blocks to GPU + """ + import sys + + # Currently only supporting single sequence for chunked prefill + assert len(seqs) == 1, "Chunked prefill only supports single sequence" + seq = seqs[0] + + total_blocks = seq.num_blocks + print(f"[Chunked Prefill] Starting: {total_blocks} total blocks, " + f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr) + + chunk_num = 0 + logits = None + + while True: + # Get chunk info (which blocks are on GPU and not yet prefilled) + chunk_info = self.kvcache_manager.get_gpu_block_tables_partial(seqs) + gpu_blocks, start_block_idx, end_block_idx = chunk_info[0] + + if not gpu_blocks: + # No more blocks to process + break + + chunk_num += 1 + chunk_tokens = (end_block_idx - start_block_idx) * self.block_size + if end_block_idx == seq.num_blocks: + # Last block may be partial + chunk_tokens = len(seq) - start_block_idx * self.block_size + + print(f"[Chunked Prefill] Chunk {chunk_num}: blocks {start_block_idx}-{end_block_idx-1}, " + f"~{chunk_tokens} tokens", file=sys.stderr) + + # Prepare inputs for this chunk + input_ids, positions = self._prepare_chunked_prefill(seq, gpu_blocks, start_block_idx, end_block_idx) + + if input_ids.numel() == 0: + print(f"[Chunked Prefill] No input tokens, breaking", file=sys.stderr) + break + + print(f"[Chunked Prefill] Running model with {input_ids.numel()} tokens...", file=sys.stderr) + + # Run model forward pass + logits = self.run_model(input_ids, positions, is_prefill=True) + reset_context() + + print(f"[Chunked Prefill] Model forward complete", file=sys.stderr) + + # Check if this is the last chunk + # Mark current chunk as prefilled and offload to CPU + self.kvcache_manager.complete_prefill_chunk(seq) + + # Check if more chunks needed + if not self.kvcache_manager.needs_chunked_prefill(seq): + print(f"[Chunked Prefill] All chunks done, sampling", file=sys.stderr) + break + + print(f"[Chunked Prefill] Chunk transfer complete, loading next...", file=sys.stderr) + + # Sample from the last chunk's logits + temperatures = self.prepare_sample(seqs) if self.rank == 0 else None + if logits is not None: + token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None + else: + token_ids = [0] if self.rank == 0 else None + + return token_ids + + def run_chunked_decode(self, seqs: list[Sequence]) -> list[int]: + """ + Run decode with chunked attention when sequence exceeds GPU capacity. + + For decode, we need attention over ALL previous tokens. With CPU offload, + we load KV chunks and compute attention incrementally. 
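A minimal sketch of the decode-side loop this describes, using the helpers added in nanovllm/kvcache/chunked_attention.py (illustrative only, not lines of this patch; kv_chunks stands for an assumed iterable of GPU-resident (k, v) chunks):

from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

def decode_attention_over_chunks(q, kv_chunks):
    # q: [1, 1, nheads_q, headdim]; each (k, v): [1, chunk_len, nheads_kv, headdim], on GPU
    acc_o, acc_lse = None, None
    for k, v in kv_chunks:
        # the decode token comes after every cached token, so no per-chunk causal mask
        o, lse = flash_attn_with_lse(q, k, v, causal=False)
        if acc_o is None:
            acc_o, acc_lse = o, lse
        else:
            acc_o, acc_lse = merge_attention_outputs(acc_o, acc_lse, o, lse)
    return acc_o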
+ """ + import sys + + # Currently only supporting single sequence for chunked decode + assert len(seqs) == 1, "Chunked decode only supports single sequence" + seq = seqs[0] + + total_blocks = len(seq.block_table) + print(f"[Chunked Decode] Sequence has {total_blocks} blocks, " + f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr) + + # Prepare inputs + input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + + # Compute slot mapping for the new token + # Get the last block's GPU slot if it's on GPU, otherwise we need to handle it + last_logical_id = seq.block_table[-1] + last_block = self.kvcache_manager.logical_blocks[last_logical_id] + + if last_block.location.name == "GPU": + slot = last_block.gpu_slot * self.block_size + seq.last_block_num_tokens - 1 + else: + # Last block is on CPU - we need to bring it to GPU for writing the new token + # This is a special case - allocate a temporary GPU slot + # For simplicity, use a fixed slot (this might conflict, but for decode + # we only write 1 token so it should be ok) + print(f"[Chunked Decode] Warning: last block on CPU, using temp slot", file=sys.stderr) + slot = 0 # Use first slot temporarily + + slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + + # Set up context for chunked decode + set_context( + is_prefill=False, # Decode mode + slot_mapping=slot_mapping, + context_lens=context_len, + is_chunked_prefill=True, # Use chunked attention + offload_engine=self.kvcache_manager, + chunked_seq=seq, + ) + + # Run model forward pass + logits = self.run_model(input_ids, positions, is_prefill=False) + reset_context() + + # Sample + temperatures = self.prepare_sample(seqs) if self.rank == 0 else None + token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None + + return token_ids + + def _prepare_chunked_prefill( + self, + seq: Sequence, + gpu_blocks: list[int], + start_block_idx: int, + end_block_idx: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Prepare inputs for a single chunk in chunked prefill. + + Sets up context with is_chunked_prefill=True so attention layers + know to load previous KV from CPU. 
+ """ + # Calculate token range for this chunk + start_token = start_block_idx * self.block_size + end_token = min(end_block_idx * self.block_size, len(seq)) + + # Input tokens for this chunk + input_ids = seq[start_token:end_token] + positions = list(range(start_token, end_token)) + + # Slot mapping for storing KV cache + slot_mapping = [] + for i, gpu_block_id in enumerate(gpu_blocks): + block_idx = start_block_idx + i + start = gpu_block_id * self.block_size + if block_idx != seq.num_blocks - 1: + end = start + self.block_size + else: + end = start + seq.last_block_num_tokens + slot_mapping.extend(list(range(start, end))) + + # Trim slot_mapping to match actual token count + actual_tokens = end_token - start_token + slot_mapping = slot_mapping[:actual_tokens] + + # Convert to tensors + input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + + # Set up context for chunked prefill + seqlen = actual_tokens + cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + + set_context( + is_prefill=True, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=seqlen, + max_seqlen_k=seqlen, + slot_mapping=slot_mapping, + is_chunked_prefill=True, + offload_engine=self.kvcache_manager, # Pass manager for loading previous KV + chunked_seq=seq, # Pass sequence for loading previous KV + ) + + return input_ids, positions + @torch.inference_mode() def capture_cudagraph(self): config = self.config diff --git a/nanovllm/engine/scheduler.py b/nanovllm/engine/scheduler.py index e3f856d..16fbc80 100644 --- a/nanovllm/engine/scheduler.py +++ b/nanovllm/engine/scheduler.py @@ -1,19 +1,22 @@ from collections import deque from time import perf_counter_ns +from typing import TYPE_CHECKING from nanovllm.config import Config from nanovllm.engine.sequence import Sequence, SequenceStatus -from nanovllm.engine.block_manager import BlockManager from nanovllm.utils.observer import Observer +if TYPE_CHECKING: + from nanovllm.kvcache import KVCacheManager + class Scheduler: - def __init__(self, config: Config): + def __init__(self, config: Config, kvcache_manager: "KVCacheManager"): self.max_num_seqs = config.max_num_seqs self.max_num_batched_tokens = config.max_num_batched_tokens self.eos = config.eos - self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size) + self.kvcache_manager = kvcache_manager self.waiting: deque[Sequence] = deque() self.running: deque[Sequence] = deque() @@ -32,10 +35,10 @@ class Scheduler: if Observer.ttft_start == 0: Observer.ttft_start = perf_counter_ns() seq = self.waiting[0] - if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq): + if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.kvcache_manager.can_allocate(seq): break num_seqs += 1 - self.block_manager.allocate(seq) + self.kvcache_manager.allocate(seq) num_batched_tokens += len(seq) - seq.num_cached_tokens seq.status = SequenceStatus.RUNNING self.waiting.popleft() @@ -47,7 +50,7 @@ class Scheduler: # decode while self.running and num_seqs < self.max_num_seqs: seq = self.running.popleft() - while not 
self.block_manager.can_append(seq): + while not self.kvcache_manager.can_append(seq): if self.running: self.preempt(self.running.pop()) else: @@ -55,7 +58,7 @@ class Scheduler: break else: num_seqs += 1 - self.block_manager.may_append(seq) + self.kvcache_manager.may_append(seq) scheduled_seqs.append(seq) assert scheduled_seqs self.running.extendleft(reversed(scheduled_seqs)) @@ -63,7 +66,7 @@ class Scheduler: def preempt(self, seq: Sequence): seq.status = SequenceStatus.WAITING - self.block_manager.deallocate(seq) + self.kvcache_manager.deallocate(seq) self.waiting.appendleft(seq) def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]: @@ -71,5 +74,5 @@ class Scheduler: seq.append_token(token_id) if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens: seq.status = SequenceStatus.FINISHED - self.block_manager.deallocate(seq) + self.kvcache_manager.deallocate(seq) self.running.remove(seq) diff --git a/nanovllm/kvcache/__init__.py b/nanovllm/kvcache/__init__.py new file mode 100644 index 0000000..a700ce5 --- /dev/null +++ b/nanovllm/kvcache/__init__.py @@ -0,0 +1,74 @@ +""" +KV Cache management module. + +This module provides pluggable KV cache management strategies: +- GPUOnlyManager: Pure GPU (default, current nano-vllm behavior) +- HybridKVCacheManager: CPU offload with CUDA Graph support + +Usage: + from nanovllm.kvcache import create_kvcache_manager + + manager = create_kvcache_manager(config) +""" + +from typing import TYPE_CHECKING + +from nanovllm.kvcache.base_manager import KVCacheManager +from nanovllm.kvcache.gpu_manager import GPUOnlyManager + +if TYPE_CHECKING: + from nanovllm.config import Config + + +def create_kvcache_manager(config: "Config") -> KVCacheManager: + """ + Factory function to create the appropriate KV cache manager. + + Decision logic: + 1. If enable_cpu_offload=False: use GPUOnlyManager + 2. If enable_cpu_offload=True but all blocks fit in GPU: use GPUOnlyManager + 3. If enable_cpu_offload=True and need CPU blocks: use HybridKVCacheManager + + Args: + config: Model configuration with offload settings + + Returns: + KVCacheManager instance + """ + if not getattr(config, 'enable_cpu_offload', False): + # Default: pure GPU mode + return GPUOnlyManager( + num_blocks=config.num_kvcache_blocks, + block_size=config.kvcache_block_size, + ) + + # CPU offload is enabled + num_gpu_blocks = config.num_gpu_kvcache_blocks + num_cpu_blocks = config.num_cpu_kvcache_blocks + + if num_cpu_blocks <= 0: + # All blocks fit in GPU, use pure GPU mode + return GPUOnlyManager( + num_blocks=num_gpu_blocks, + block_size=config.kvcache_block_size, + ) + + # Need CPU offload: use hybrid manager + from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager + from nanovllm.kvcache.policies import get_policy + + policy = get_policy(getattr(config, 'offload_policy', 'lru')) + + return HybridKVCacheManager( + num_gpu_slots=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks, + block_size=config.kvcache_block_size, + policy=policy, + ) + + +__all__ = [ + "KVCacheManager", + "GPUOnlyManager", + "create_kvcache_manager", +] diff --git a/nanovllm/kvcache/base_manager.py b/nanovllm/kvcache/base_manager.py new file mode 100644 index 0000000..3f33bb6 --- /dev/null +++ b/nanovllm/kvcache/base_manager.py @@ -0,0 +1,260 @@ +""" +Abstract base class for KV cache managers. 
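For orientation, a hedged sketch of how the factory above picks a manager; the fields mirror what ModelRunner fills in, and the namespace is only a stand-in for a real Config:

from types import SimpleNamespace
from nanovllm.kvcache import create_kvcache_manager

cfg = SimpleNamespace(
    enable_cpu_offload=True,
    kvcache_block_size=256,
    num_kvcache_blocks=1024,      # GPU + CPU, kept for backward compatibility
    num_gpu_kvcache_blocks=256,   # GPU working set
    num_cpu_kvcache_blocks=768,   # pinned-memory overflow pool
    offload_policy="lru",
)
manager = create_kvcache_manager(cfg)  # -> HybridKVCacheManager with an LRU policy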
+ +This interface allows pluggable implementations: +- GPUOnlyManager: Pure GPU (current nano-vllm behavior) +- HybridKVCacheManager: CPU offload with CUDA Graph support +- Future: Disk offload, distributed cache, etc. +""" + +from abc import ABC, abstractmethod +from typing import List, Tuple, Optional +import torch +from torch import Tensor + +from nanovllm.engine.sequence import Sequence + + +class KVCacheManager(ABC): + """ + Abstract base class for KV cache management strategies. + + A KVCacheManager handles: + 1. Physical memory allocation (GPU and optionally CPU) + 2. Logical block management (allocation, deallocation, prefix caching) + 3. Data transfer between devices (for hybrid managers) + 4. Integration with CUDA graphs + + Key design principles: + - Sequences reference logical block IDs + - Physical block IDs (GPU slots) may differ from logical IDs + - CUDA Graph compatibility requires fixed tensor addresses + """ + + @property + @abstractmethod + def block_size(self) -> int: + """Number of tokens per block.""" + pass + + @property + @abstractmethod + def num_free_blocks(self) -> int: + """Number of free logical blocks available for allocation.""" + pass + + @abstractmethod + def allocate_cache( + self, + num_layers: int, + num_kv_heads: int, + head_dim: int, + dtype: torch.dtype, + ) -> None: + """ + Allocate KV cache storage. + + Called once during initialization to allocate GPU (and optionally CPU) + memory for the KV cache. + + Args: + num_layers: Number of transformer layers + num_kv_heads: Number of key-value heads per layer + head_dim: Dimension per head + dtype: Data type for cache (e.g., torch.float16) + """ + pass + + @abstractmethod + def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]: + """ + Get K and V cache tensors for a specific layer. + + The returned tensors must be on GPU and have fixed addresses + for CUDA Graph compatibility. + + Args: + layer_id: Layer index + + Returns: + (k_cache, v_cache) tensors + Shape depends on implementation, typically: + [num_blocks, block_size, kv_heads, head_dim] + """ + pass + + @abstractmethod + def can_allocate(self, seq: Sequence) -> bool: + """ + Check if blocks can be allocated for a new sequence. + + Called before allocate() to ensure sufficient resources. + + Args: + seq: Sequence to check + + Returns: + True if allocation is possible + """ + pass + + @abstractmethod + def allocate(self, seq: Sequence) -> None: + """ + Allocate blocks for a sequence during prefill. + + This method: + 1. Checks prefix cache for matching blocks + 2. Allocates new blocks as needed + 3. Updates seq.block_table with logical block IDs + 4. Updates seq.num_cached_tokens for prefix cache hits + + Args: + seq: Sequence to allocate blocks for + """ + pass + + @abstractmethod + def deallocate(self, seq: Sequence) -> None: + """ + Release blocks for a finished sequence. + + This method: + 1. Decrements reference counts + 2. Frees blocks with zero references + 3. Clears seq.block_table + + Args: + seq: Sequence whose blocks to release + """ + pass + + @abstractmethod + def can_append(self, seq: Sequence) -> bool: + """ + Check if a new block can be allocated for decode. + + Called before may_append() to check if resources are available. + + Args: + seq: Sequence to check + + Returns: + True if append is possible (or no new block needed) + """ + pass + + @abstractmethod + def may_append(self, seq: Sequence) -> None: + """ + Potentially allocate a new block during decode. + + Called after each decode step. 
If the current block is full, + allocates a new block and updates seq.block_table. + + Args: + seq: Sequence that may need a new block + """ + pass + + @abstractmethod + def prepare_for_attention( + self, + seqs: List[Sequence], + is_prefill: bool, + ) -> None: + """ + Prepare KV cache for attention computation. + + For GPU-only managers: typically a no-op. + For hybrid managers: ensures all needed blocks are on GPU, + may trigger prefetching from CPU. + + Called before attention computation. For decode with CUDA graphs, + this should update gather_indices but not perform actual transfers + (transfers happen inside the graph). + + Args: + seqs: Sequences that will be processed + is_prefill: True for prefill phase, False for decode + """ + pass + + @abstractmethod + def get_gpu_block_tables( + self, + seqs: List[Sequence], + ) -> List[List[int]]: + """ + Get GPU physical block tables for sequences. + + For GPU-only managers: returns seq.block_table directly. + For hybrid managers: returns GPU slot IDs (may differ from logical IDs). + + The returned block tables are used to compute slot_mapping + in ModelRunner.prepare_prefill/decode. + + Args: + seqs: Sequences to get block tables for + + Returns: + List of GPU block tables, one per sequence + """ + pass + + def post_attention_cleanup( + self, + seqs: List[Sequence], + is_prefill: bool, + ) -> None: + """ + Cleanup after attention computation. + + Optional hook for managers to perform post-attention tasks: + - Offloading cold blocks to CPU + - Updating access statistics + - etc. + + Default implementation does nothing. + + Args: + seqs: Sequences that were processed + is_prefill: True for prefill phase, False for decode + """ + pass + + def get_num_blocks_needed(self, num_tokens: int) -> int: + """ + Calculate number of blocks needed for given token count. + + Args: + num_tokens: Number of tokens + + Returns: + Number of blocks needed + """ + return (num_tokens + self.block_size - 1) // self.block_size + + @staticmethod + def compute_hash(token_ids: list, prefix: int = -1) -> int: + """ + Compute hash for prefix caching. + + Uses xxhash for fast hashing. The hash includes the prefix hash + to create a chain of hashes for multi-block sequences. + + Args: + token_ids: Token IDs in the block + prefix: Hash of previous block, or -1 for first block + + Returns: + Hash value + """ + import xxhash + import numpy as np + + h = xxhash.xxh64() + if prefix != -1: + h.update(prefix.to_bytes(8, "little")) + h.update(np.array(token_ids).tobytes()) + return h.intdigest() diff --git a/nanovllm/kvcache/chunked_attention.py b/nanovllm/kvcache/chunked_attention.py new file mode 100644 index 0000000..b06dd05 --- /dev/null +++ b/nanovllm/kvcache/chunked_attention.py @@ -0,0 +1,555 @@ +""" +Chunked attention implementation for CPU KV cache offloading. + +This module implements flash attention with LSE (log-sum-exp) output, +enabling proper online softmax merging for chunked prefill. 
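What the LSE output buys can be checked on CPU with plain PyTorch; this sketch (not part of the patch) merges attention over two key chunks and reproduces full-context softmax attention:

import torch

torch.manual_seed(0)
q, k, v = torch.randn(4, 8), torch.randn(16, 8), torch.randn(16, 8)
scale = 8 ** -0.5

def attn_with_lse(q, k, v):
    s = (q @ k.T) * scale
    return torch.softmax(s, dim=-1) @ v, torch.logsumexp(s, dim=-1)

o1, lse1 = attn_with_lse(q, k[:8], v[:8])   # first KV chunk
o2, lse2 = attn_with_lse(q, k[8:], v[8:])   # second KV chunk
m = torch.maximum(lse1, lse2)
w1, w2 = torch.exp(lse1 - m), torch.exp(lse2 - m)
o_merged = (o1 * w1[:, None] + o2 * w2[:, None]) / (w1 + w2)[:, None]
o_full, _ = attn_with_lse(q, k, v)
assert torch.allclose(o_merged, o_full, atol=1e-5)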
+ +Key functions: +- flash_attn_with_lse: Flash attention that returns output and LSE +- merge_attention_outputs: Merge outputs from multiple KV chunks +- chunked_prefill_attention: High-level interface for chunked attention +""" + +import math +import torch +import triton +import triton.language as tl +from typing import Tuple, List, Optional + + +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel_with_lse( + Q, + K, + V, + Out, + Lse, + TMP, + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_ob, + stride_oh, + stride_om, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """Flash attention forward kernel with LSE output for online softmax.""" + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + + q_ptrs = ( + Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + ) + k_ptrs = ( + K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + ) + v_ptrs = ( + V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + ) + + t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m + lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + + # Load Q + if EVEN_M & EVEN_N: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 + ) + + # Loop over k, v and update accumulator + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + + # Load K + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + + # Compute QK + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, tl.trans(k)) + + # Masking + if not EVEN_N: + qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + + # Online softmax + m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) + p = tl.exp(qk * softmax_scale - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # Scale acc_o + 
acc_o_scale = tl.exp(m_i - m_ij) + tl.store(t_ptrs, acc_o_scale) + acc_o_scale = tl.load(t_ptrs) + acc_o = acc_o * acc_o_scale[:, None] + + # Load V and update output + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + + # Update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + # Final scaling + o_scale = tl.exp(m_i - lse_i) + tl.store(t_ptrs, o_scale) + o_scale = tl.load(t_ptrs) + acc_o = acc_o * o_scale[:, None] + + # Store LSE + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m + tl.store(lse_ptrs, lse_i) + + # Store output + out_ptrs = ( + Out + + off_b * stride_ob + + off_h * stride_oh + + (offs_m[:, None] * stride_om + offs_d[None, :]) + ) + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store( + out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) + ) + + +def flash_attn_with_lse( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: Optional[float] = None, + causal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Flash attention forward pass that returns both output and LSE. + + Supports GQA (grouped query attention) where num_kv_heads < num_q_heads. 
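The GQA support is plain head replication before the kernel launch; a shape-only illustration (not part of the patch):

import torch
k = torch.randn(1, 128, 8, 64)              # [batch, seqlen_k, nheads_kv=8, headdim]
k_rep = k.repeat_interleave(4, dim=2)       # nheads_q=32 -> repeat factor 4
assert k_rep.shape == (1, 128, 32, 64)
assert torch.equal(k_rep[:, :, 3], k[:, :, 0])   # query heads 0-3 all read KV head 0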
+ + Args: + q: Query tensor [batch, seqlen_q, nheads_q, headdim] + k: Key tensor [batch, seqlen_k, nheads_kv, headdim] + v: Value tensor [batch, seqlen_k, nheads_kv, headdim] + softmax_scale: Scaling factor (default: 1/sqrt(headdim)) + causal: Whether to apply causal masking + + Returns: + out: Output tensor [batch, seqlen_q, nheads_q, headdim] + lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q] + """ + # Ensure contiguous + if not q.is_contiguous(): + q = q.contiguous() + if not k.is_contiguous(): + k = k.contiguous() + if not v.is_contiguous(): + v = v.contiguous() + + batch, seqlen_q, nheads_q, headdim = q.shape + _, seqlen_k, nheads_kv, _ = k.shape + + assert k.shape == (batch, seqlen_k, nheads_kv, headdim) + assert v.shape == (batch, seqlen_k, nheads_kv, headdim) + assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" + assert q.dtype == k.dtype == v.dtype + + # Handle GQA by repeating K/V heads + if nheads_kv != nheads_q: + assert nheads_q % nheads_kv == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_kv ({nheads_kv})" + repeat_factor = nheads_q // nheads_kv + # [batch, seqlen_k, nheads_kv, headdim] -> [batch, seqlen_k, nheads_q, headdim] + k = k.repeat_interleave(repeat_factor, dim=2) + v = v.repeat_interleave(repeat_factor, dim=2) + nheads = nheads_q + else: + nheads = nheads_q + + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(headdim) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + out = torch.empty_like(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16) + BLOCK = 128 + num_warps = 4 if headdim <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + + _fwd_kernel_with_lse[grid]( + q, + k, + v, + out, + lse, + tmp, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + out.stride(0), + out.stride(2), + out.stride(1), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + seqlen_q // 32, + seqlen_k // 32, + causal, + BLOCK_HEADDIM, + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + + # Trim LSE to actual seqlen_q + lse = lse[:, :, :seqlen_q] + + # Ensure output has same dtype as input + out = out.to(q.dtype) + + return out, lse + + +def merge_attention_outputs( + o1: torch.Tensor, + lse1: torch.Tensor, + o2: torch.Tensor, + lse2: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Merge two attention outputs using online softmax. 
+ + This implements the online softmax merging formula: + - m_new = max(lse1, lse2) + - o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2) / (exp(lse1 - m_new) + exp(lse2 - m_new)) + - lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new)) + + Args: + o1: First output [batch, seqlen_q, nheads, headdim] + lse1: First LSE [batch, nheads, seqlen_q] + o2: Second output [batch, seqlen_q, nheads, headdim] + lse2: Second LSE [batch, nheads, seqlen_q] + + Returns: + o_merged: Merged output [batch, seqlen_q, nheads, headdim] + lse_merged: Merged LSE [batch, nheads, seqlen_q] + """ + # lse shape: [batch, nheads, seqlen_q] + # o shape: [batch, seqlen_q, nheads, headdim] + + # Compute max for numerical stability + max_lse = torch.maximum(lse1, lse2) + + # Compute scaling factors + # exp1, exp2 shape: [batch, nheads, seqlen_q] + exp1 = torch.exp(lse1 - max_lse) + exp2 = torch.exp(lse2 - max_lse) + + # Reshape for broadcasting with output + # [batch, nheads, seqlen_q] -> [batch, seqlen_q, nheads, 1] + exp1_broad = exp1.transpose(1, 2).unsqueeze(-1) + exp2_broad = exp2.transpose(1, 2).unsqueeze(-1) + + # Merge outputs + sum_exp = exp1_broad + exp2_broad + o_merged = (o1 * exp1_broad + o2 * exp2_broad) / sum_exp + + # Compute merged LSE + lse_merged = max_lse + torch.log(exp1 + exp2) + + # Ensure output has same dtype as input + o_merged = o_merged.to(o1.dtype) + + return o_merged, lse_merged + + +def chunked_attention_varlen( + q: torch.Tensor, + kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k_list: List[torch.Tensor], + max_seqlen_q: int, + max_seqlen_k_list: List[int], + softmax_scale: Optional[float] = None, + causal_mask_per_chunk: Optional[List[bool]] = None, +) -> torch.Tensor: + """ + Compute attention with KV split across multiple chunks. + + This is the core function for chunked prefill. It computes attention + against each KV chunk and merges results using online softmax. 
+ + For causal attention with chunked KV: + - First chunk (current tokens): Apply causal mask + - Previous chunks: No causal mask (all previous tokens are valid context) + + Args: + q: Query tensor [total_q_tokens, nheads, headdim] + kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim] + cu_seqlens_q: Cumulative sequence lengths for Q [batch+1] + cu_seqlens_k_list: List of cumulative sequence lengths for each KV chunk + max_seqlen_q: Maximum query sequence length + max_seqlen_k_list: List of maximum key sequence lengths for each chunk + softmax_scale: Scaling factor + causal_mask_per_chunk: Whether to apply causal mask for each chunk + + Returns: + out: Output tensor [total_q_tokens, nheads, headdim] + """ + if len(kv_chunks) == 0: + raise ValueError("Need at least one KV chunk") + + nheads = q.shape[1] + headdim = q.shape[2] + batch = cu_seqlens_q.shape[0] - 1 + + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(headdim) + + if causal_mask_per_chunk is None: + # Default: causal for last chunk only + causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True] + + # Initialize accumulated output and LSE + accumulated_o = None + accumulated_lse = None + + for chunk_idx, (k_chunk, v_chunk) in enumerate(kv_chunks): + is_causal = causal_mask_per_chunk[chunk_idx] + + # Reshape Q for batch processing + # For varlen, we need to handle each sequence separately + # For simplicity, assume single sequence (batch=1) for now + q_batched = q.unsqueeze(0) # [1, total_q, nheads, headdim] + + # Compute attention for this chunk + chunk_o, chunk_lse = flash_attn_with_lse( + q_batched, + k_chunk, + v_chunk, + softmax_scale=softmax_scale, + causal=is_causal, + ) + + # Merge with accumulated + if accumulated_o is None: + accumulated_o = chunk_o + accumulated_lse = chunk_lse + else: + accumulated_o, accumulated_lse = merge_attention_outputs( + accumulated_o, accumulated_lse, + chunk_o, chunk_lse, + ) + + # Remove batch dimension + return accumulated_o.squeeze(0) + + +class ChunkedPrefillState: + """ + State for tracking chunked prefill progress. + + This class maintains the accumulated attention output and LSE + across multiple prefill chunks. 
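A hedged usage sketch (not part of the patch); the tensors are random stand-ins with the shapes merge_attention_outputs expects:

import torch
from nanovllm.kvcache.chunked_attention import ChunkedPrefillState

state = ChunkedPrefillState(num_layers=2, dtype=torch.float32, device="cpu")
for _chunk in range(3):
    o = torch.randn(1, 16, 4, 64)    # [batch, seqlen_q, nheads, headdim]
    lse = torch.randn(1, 4, 16)      # [batch, nheads, seqlen_q]
    state.update_layer(0, o, lse)
out = state.get_layer_output(0)      # merged output for layer 0
state.clear()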
+ """ + + def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device): + self.num_layers = num_layers + self.dtype = dtype + self.device = device + + # Per-layer accumulated outputs + # Each entry: (accumulated_output, accumulated_lse) or None + self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [ + None for _ in range(num_layers) + ] + + # Track which chunks have been processed + self.processed_chunks: int = 0 + + def update_layer( + self, + layer_id: int, + chunk_output: torch.Tensor, + chunk_lse: torch.Tensor, + ): + """Update accumulated state for a layer with a new chunk's output.""" + if self.layer_states[layer_id] is None: + self.layer_states[layer_id] = (chunk_output, chunk_lse) + else: + acc_o, acc_lse = self.layer_states[layer_id] + merged_o, merged_lse = merge_attention_outputs( + acc_o, acc_lse, + chunk_output, chunk_lse, + ) + self.layer_states[layer_id] = (merged_o, merged_lse) + + def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]: + """Get the final accumulated output for a layer.""" + if self.layer_states[layer_id] is None: + return None + return self.layer_states[layer_id][0] + + def clear(self): + """Clear all accumulated state.""" + self.layer_states = [None for _ in range(self.num_layers)] + self.processed_chunks = 0 + + +# Test function +def _test_chunked_attention(): + """Test chunked attention correctness against full attention.""" + from flash_attn import flash_attn_func + + torch.manual_seed(42) + + batch, seqlen, nheads, headdim = 1, 1024, 32, 128 + + # Generate random Q, K, V + q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16) + k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16) + v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16) + + # Full attention (reference) + out_ref = flash_attn_func(q, k, v, causal=True) + + # Chunked attention + chunk_size = 256 + num_chunks = seqlen // chunk_size + + accumulated_o = None + accumulated_lse = None + + for i in range(num_chunks): + start = i * chunk_size + end = (i + 1) * chunk_size + + # Q for this chunk + q_chunk = q[:, start:end, :, :] + + # K, V up to current position (for causal) + k_context = k[:, :end, :, :] + v_context = v[:, :end, :, :] + + # Compute attention + chunk_o, chunk_lse = flash_attn_with_lse( + q_chunk, k_context, v_context, causal=True + ) + + if accumulated_o is None: + accumulated_o = chunk_o + accumulated_lse = chunk_lse + else: + # For chunked prefill, we need to concatenate outputs, not merge + # Because each chunk's Q attends to different K positions + accumulated_o = torch.cat([accumulated_o, chunk_o], dim=1) + + # Compare + max_diff = (out_ref - accumulated_o).abs().max().item() + print(f"Max difference: {max_diff}") + assert max_diff < 1e-2, f"Chunked attention differs from reference: {max_diff}" + print("Test passed!") + + +if __name__ == "__main__": + _test_chunked_attention() diff --git a/nanovllm/kvcache/gpu_manager.py b/nanovllm/kvcache/gpu_manager.py new file mode 100644 index 0000000..ad8e40f --- /dev/null +++ b/nanovllm/kvcache/gpu_manager.py @@ -0,0 +1,262 @@ +""" +GPU-only KV cache manager. + +This is the default manager when CPU offload is disabled. +Refactored from the original block_manager.py to implement +the KVCacheManager interface. 
+""" + +from collections import deque +from typing import List, Tuple, Dict, Optional +import torch +from torch import Tensor + +from nanovllm.engine.sequence import Sequence +from nanovllm.kvcache.base_manager import KVCacheManager + + +class Block: + """Physical block in GPU memory.""" + + def __init__(self, block_id: int): + self.block_id = block_id + self.ref_count = 0 + self.hash = -1 + self.token_ids: List[int] = [] + + def update(self, hash: int, token_ids: List[int]): + self.hash = hash + self.token_ids = token_ids + + def reset(self): + self.ref_count = 1 + self.hash = -1 + self.token_ids = [] + + +class GPUOnlyManager(KVCacheManager): + """ + Pure GPU KV cache manager. + + This is the default implementation when enable_cpu_offload=False. + All KV cache resides in GPU memory. + + Features: + - Paged attention with configurable block size + - Prefix caching via xxhash + - Reference counting for block sharing + + This manager is fully compatible with CUDA graphs since + all data stays on GPU at fixed addresses. + """ + + def __init__(self, num_blocks: int, block_size: int): + """ + Initialize GPU-only manager. + + Args: + num_blocks: Total number of blocks to manage + block_size: Tokens per block (default 256) + """ + self._block_size = block_size + self._num_blocks = num_blocks + + # Block metadata + self.blocks: List[Block] = [Block(i) for i in range(num_blocks)] + + # Prefix cache: hash -> block_id + self.hash_to_block_id: Dict[int, int] = {} + + # Free/used tracking + self.free_block_ids: deque[int] = deque(range(num_blocks)) + self.used_block_ids: set[int] = set() + + # KV cache tensors (set by allocate_cache) + self.kv_cache: Optional[Tensor] = None + self.num_layers: int = 0 + self.num_kv_heads: int = 0 + self.head_dim: int = 0 + + @property + def block_size(self) -> int: + return self._block_size + + @property + def num_free_blocks(self) -> int: + return len(self.free_block_ids) + + def allocate_cache( + self, + num_layers: int, + num_kv_heads: int, + head_dim: int, + dtype: torch.dtype, + ) -> None: + """Allocate GPU KV cache tensor.""" + self.num_layers = num_layers + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + + # Shape: [2, num_layers, num_blocks, block_size, kv_heads, head_dim] + # 2 for K and V + self.kv_cache = torch.empty( + 2, num_layers, self._num_blocks, self._block_size, + num_kv_heads, head_dim, + dtype=dtype, device="cuda" + ) + + def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]: + """Get K/V cache for a layer.""" + assert self.kv_cache is not None, "Cache not allocated" + return self.kv_cache[0, layer_id], self.kv_cache[1, layer_id] + + def _allocate_block(self, block_id: int) -> Block: + """Internal: allocate a specific block.""" + block = self.blocks[block_id] + assert block.ref_count == 0, f"Block {block_id} is not free" + block.reset() + self.free_block_ids.remove(block_id) + self.used_block_ids.add(block_id) + return block + + def _deallocate_block(self, block_id: int) -> None: + """Internal: deallocate a block.""" + assert self.blocks[block_id].ref_count == 0 + self.used_block_ids.remove(block_id) + self.free_block_ids.append(block_id) + + def can_allocate(self, seq: Sequence) -> bool: + """Check if we have enough blocks for the sequence.""" + return len(self.free_block_ids) >= seq.num_blocks + + def allocate(self, seq: Sequence) -> None: + """ + Allocate blocks for a sequence during prefill. + + Implements prefix caching: if a block's content matches + a previously cached block, reuse it instead of allocating new. 
+ """ + assert not seq.block_table, "Sequence already has blocks allocated" + + h = -1 # Hash chain + cache_miss = False + + for i in range(seq.num_blocks): + token_ids = seq.block(i) + + # Only compute hash for full blocks + if len(token_ids) == self._block_size: + h = self.compute_hash(token_ids, h) + else: + h = -1 + + # Try prefix cache lookup + block_id = self.hash_to_block_id.get(h, -1) + if block_id == -1 or self.blocks[block_id].token_ids != token_ids: + cache_miss = True + + if cache_miss: + # Cache miss: allocate new block + block_id = self.free_block_ids[0] + block = self._allocate_block(block_id) + else: + # Cache hit: reuse existing block + seq.num_cached_tokens += self._block_size + if block_id in self.used_block_ids: + # Block is in use, increment ref count + block = self.blocks[block_id] + block.ref_count += 1 + else: + # Block was freed but hash still valid + block = self._allocate_block(block_id) + + # Update hash mapping for full blocks + if h != -1: + block.update(h, token_ids) + self.hash_to_block_id[h] = block_id + + seq.block_table.append(block_id) + + def deallocate(self, seq: Sequence) -> None: + """Release all blocks for a sequence.""" + for block_id in reversed(seq.block_table): + block = self.blocks[block_id] + block.ref_count -= 1 + if block.ref_count == 0: + self._deallocate_block(block_id) + + seq.num_cached_tokens = 0 + seq.block_table.clear() + + def can_append(self, seq: Sequence) -> bool: + """Check if we can append a token (may need new block).""" + # Need new block only if current position is at block boundary + need_new_block = (len(seq) % self._block_size == 1) + return len(self.free_block_ids) >= int(need_new_block) + + def may_append(self, seq: Sequence) -> None: + """Handle potential new block allocation during decode.""" + block_table = seq.block_table + last_block = self.blocks[block_table[-1]] + + seq_len = len(seq) + pos_in_block = seq_len % self._block_size + + if pos_in_block == 1: + # Just crossed into new block, need to allocate + assert last_block.hash != -1, "Previous block should be complete" + block_id = self.free_block_ids[0] + self._allocate_block(block_id) + block_table.append(block_id) + + elif pos_in_block == 0: + # Just filled a block, compute hash for prefix cache + assert last_block.hash == -1, "Block should not have hash yet" + token_ids = seq.block(seq.num_blocks - 1) + prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1 + h = self.compute_hash(token_ids, prefix) + last_block.update(h, token_ids) + self.hash_to_block_id[h] = last_block.block_id + + else: + # Middle of block, nothing to do + assert last_block.hash == -1 + + def prepare_for_attention( + self, + seqs: List[Sequence], + is_prefill: bool, + ) -> None: + """ + No-op for GPU-only manager. + + All blocks are already on GPU, no preparation needed. + """ + pass + + def get_gpu_block_tables( + self, + seqs: List[Sequence], + ) -> List[List[int]]: + """ + Return block tables directly (logical = physical for GPU-only). 
+ """ + return [list(seq.block_table) for seq in seqs] + + def post_attention_cleanup( + self, + seqs: List[Sequence], + is_prefill: bool, + ) -> None: + """No-op for GPU-only manager.""" + pass + + def __repr__(self) -> str: + return ( + f"GPUOnlyManager(" + f"num_blocks={self._num_blocks}, " + f"block_size={self._block_size}, " + f"free={len(self.free_block_ids)}, " + f"used={len(self.used_block_ids)}" + f")" + ) diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py new file mode 100644 index 0000000..27b0269 --- /dev/null +++ b/nanovllm/kvcache/hybrid_manager.py @@ -0,0 +1,906 @@ +""" +Hybrid CPU-GPU KV cache manager with CUDA Graph support. + +Key design for CUDA Graph compatibility: +1. GPU buffer has fixed addresses (allocated once) +2. CPU pool has fixed addresses (pinned memory) +3. gather_indices tensor has fixed address, variable content +4. H2D transfer uses gathered_copy kernel inside CUDA graphs +5. Graph replay only needs index updates (tiny overhead) +""" + +from collections import deque +from dataclasses import dataclass, field +from enum import Enum, auto +from typing import List, Tuple, Dict, Set, Optional +import torch +from torch import Tensor + +from nanovllm.engine.sequence import Sequence +from nanovllm.kvcache.base_manager import KVCacheManager +from nanovllm.kvcache.offload_engine import OffloadEngine +from nanovllm.kvcache.policies.base_policy import EvictionPolicy +from nanovllm.kvcache.policies.lru_policy import LRUPolicy + + +class BlockLocation(Enum): + """Where a logical block's data currently resides.""" + GPU = auto() + CPU = auto() + INVALID = auto() # Not yet written / deallocated + + +@dataclass +class LogicalBlock: + """ + Logical block that can be mapped to GPU or CPU physical storage. + + Sequences reference logical blocks. Physical blocks are the actual + storage locations (GPU slots or CPU blocks). + """ + logical_id: int + location: BlockLocation = BlockLocation.INVALID + gpu_slot: int = -1 # GPU buffer slot ID (if on GPU) + cpu_block_id: int = -1 # CPU pool block ID (if on CPU) + ref_count: int = 0 + hash: int = -1 + token_ids: List[int] = field(default_factory=list) + + def reset(self): + self.location = BlockLocation.INVALID + self.gpu_slot = -1 + self.cpu_block_id = -1 + self.ref_count = 0 + self.hash = -1 + self.token_ids = [] + + +class HybridKVCacheManager(KVCacheManager): + """ + Hybrid CPU-GPU KV cache manager with CUDA Graph support. + + Architecture: + - GPU buffer: Fixed-size working set (num_gpu_slots) + - CPU pool: Overflow storage (num_cpu_blocks) + - Logical blocks: What sequences reference (num_gpu_slots + num_cpu_blocks) + + CUDA Graph compatibility: + - All tensor addresses fixed at init time + - prepare_for_attention() updates gather_indices (outside graph) + - gathered_h2d_layer() executes transfer (inside graph) + + Strategy: + 1. New KV data written to GPU slots + 2. Cold blocks evicted to CPU using configurable policy + 3. Needed blocks prefetched back to GPU before attention + """ + + def __init__( + self, + num_gpu_slots: int, + num_cpu_blocks: int, + block_size: int, + policy: Optional[EvictionPolicy] = None, + ): + """ + Initialize hybrid manager. 
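+
+        Illustrative construction (the sizes below are made up for the
+        sketch, not recommended values):
+
+        ```python
+        manager = HybridKVCacheManager(
+            num_gpu_slots=64,      # working set resident in GPU memory
+            num_cpu_blocks=1024,   # pinned-memory overflow pool
+            block_size=256,
+            policy=None,           # defaults to LRUPolicy()
+        )
+        manager.allocate_cache(num_layers=32, num_kv_heads=8,
+                               head_dim=128, dtype=torch.float16)
+        ```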
+ + Args: + num_gpu_slots: Number of GPU buffer slots (working set) + num_cpu_blocks: Number of CPU pool blocks (overflow) + block_size: Tokens per block + policy: Eviction policy (default: LRU) + """ + self._block_size = block_size + self.num_gpu_slots = num_gpu_slots + self.num_cpu_blocks = num_cpu_blocks + self.total_blocks = num_gpu_slots + num_cpu_blocks + + # Eviction policy + self.policy = policy or LRUPolicy() + + # Logical blocks (what sequences reference) + self.logical_blocks: List[LogicalBlock] = [ + LogicalBlock(i) for i in range(self.total_blocks) + ] + self.free_logical_ids: deque[int] = deque(range(self.total_blocks)) + + # GPU slot management (slots are fixed, mapping is variable) + self.free_gpu_slots: deque[int] = deque(range(num_gpu_slots)) + self.gpu_slot_to_logical: Dict[int, int] = {} # gpu_slot -> logical_id + + # CPU block management + self.free_cpu_blocks: deque[int] = deque(range(num_cpu_blocks)) + self.cpu_block_to_logical: Dict[int, int] = {} # cpu_block -> logical_id + + # Prefix cache (uses logical block IDs) + self.hash_to_logical_id: Dict[int, int] = {} + + # Step counter for policy + self.current_step = 0 + + # Offload engine (set by allocate_cache) + self.offload_engine: Optional[OffloadEngine] = None + + # Track blocks pending GPU load (for decode graph) + self.pending_gpu_loads: Set[int] = set() # logical_ids + + # Track blocks that have been prefilled (KV written) for chunked prefill + self.prefilled_blocks: Set[int] = set() # logical_ids + + @property + def block_size(self) -> int: + return self._block_size + + @property + def num_free_blocks(self) -> int: + return len(self.free_logical_ids) + + def allocate_cache( + self, + num_layers: int, + num_kv_heads: int, + head_dim: int, + dtype: torch.dtype, + ) -> None: + """Initialize the offload engine with actual cache storage.""" + self.offload_engine = OffloadEngine( + num_layers=num_layers, + num_gpu_blocks=self.num_gpu_slots, + num_cpu_blocks=self.num_cpu_blocks, + block_size=self._block_size, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + dtype=dtype, + ) + + def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]: + """Get GPU K/V cache tensors for a layer.""" + assert self.offload_engine is not None + return self.offload_engine.get_layer_cache(layer_id) + + def _allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> int: + """ + Get a free GPU slot, evicting if necessary. + + Args: + protected_logical_ids: Logical block IDs that cannot be evicted + + Returns: + GPU slot ID + + Raises: + RuntimeError: If no GPU slot is available + """ + if self.free_gpu_slots: + return self.free_gpu_slots.popleft() + + # Need to evict - find victim using policy + return self._evict_to_cpu(protected_logical_ids) + + def _try_allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> Optional[int]: + """ + Try to get a free GPU slot, evicting if necessary. + + Unlike _allocate_gpu_slot(), returns None instead of raising if no eviction possible. 
+ + Args: + protected_logical_ids: Logical block IDs that cannot be evicted + + Returns: + GPU slot ID, or None if no slot available + """ + if self.free_gpu_slots: + return self.free_gpu_slots.popleft() + + # Check if we can evict + protected = protected_logical_ids or set() + for gpu_slot, logical_id in self.gpu_slot_to_logical.items(): + if logical_id not in protected: + block = self.logical_blocks[logical_id] + if block.ref_count > 0: + # Found evictable block + return self._evict_to_cpu(protected_logical_ids) + + # No evictable blocks + return None + + def _evict_to_cpu(self, protected_logical_ids: Optional[Set[int]] = None) -> int: + """ + Evict a GPU block to CPU to make room. + + Args: + protected_logical_ids: Logical block IDs that cannot be evicted + + Returns: + The freed GPU slot ID + """ + protected = protected_logical_ids or set() + + # Find candidates (blocks currently on GPU with ref_count > 0, excluding protected) + candidates: Set[int] = set() + for gpu_slot, logical_id in self.gpu_slot_to_logical.items(): + if logical_id in protected: + continue # Skip protected blocks + block = self.logical_blocks[logical_id] + if block.ref_count > 0: # Only evict blocks still in use + candidates.add(gpu_slot) + + if not candidates: + raise RuntimeError( + f"No GPU slots available for eviction. " + f"GPU slots: {self.num_gpu_slots}, protected: {len(protected)}, " + f"need more GPU memory or reduce sequence length" + ) + + # Use policy to select victim + victim_gpu_slot = self.policy.select_victim(candidates) + logical_id = self.gpu_slot_to_logical[victim_gpu_slot] + block = self.logical_blocks[logical_id] + + # Allocate CPU block + if not self.free_cpu_blocks: + raise RuntimeError("Both GPU and CPU are full") + cpu_block_id = self.free_cpu_blocks.popleft() + + # Async offload GPU -> CPU + self.offload_engine.offload_block_async( + layer_id=0, # TODO: handle per-layer offloading + gpu_block_id=victim_gpu_slot, + cpu_block_id=cpu_block_id, + ) + + # Update mappings + del self.gpu_slot_to_logical[victim_gpu_slot] + self.cpu_block_to_logical[cpu_block_id] = logical_id + + block.location = BlockLocation.CPU + block.gpu_slot = -1 + block.cpu_block_id = cpu_block_id + + # Notify policy + self.policy.on_block_evicted(victim_gpu_slot) + + return victim_gpu_slot + + def _ensure_on_gpu( + self, + logical_id: int, + protected_logical_ids: Optional[Set[int]] = None, + ) -> int: + """ + Ensure a logical block is on GPU. 
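+
+        If the block is already resident this only refreshes the policy's
+        access bookkeeping; if it lives on CPU, a GPU slot is obtained
+        (evicting another block if necessary) and an async prefetch is
+        issued. Illustrative sketch of the resulting state change:
+
+        ```python
+        block = manager.logical_blocks[logical_id]
+        assert block.location == BlockLocation.CPU
+        gpu_slot = manager._ensure_on_gpu(logical_id)
+        assert block.location == BlockLocation.GPU
+        assert block.gpu_slot == gpu_slot and block.cpu_block_id == -1
+        # the copy itself is asynchronous; callers wait via
+        # manager.offload_engine.wait_all_transfers() before attention
+        ```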
+ + Args: + logical_id: Logical block ID + protected_logical_ids: Logical block IDs that cannot be evicted + + Returns: + GPU slot ID where the block is/will be + """ + block = self.logical_blocks[logical_id] + + if block.location == BlockLocation.GPU: + # Already on GPU, update policy + self.policy.on_block_access(block.gpu_slot, self.current_step) + return block.gpu_slot + + if block.location == BlockLocation.CPU: + # Need to prefetch from CPU + gpu_slot = self._allocate_gpu_slot(protected_logical_ids) + + # Async prefetch CPU -> GPU + self.offload_engine.prefetch_block_async( + layer_id=0, # TODO: handle per-layer + cpu_block_id=block.cpu_block_id, + gpu_block_id=gpu_slot, + ) + + # Update mappings + self.free_cpu_blocks.append(block.cpu_block_id) + del self.cpu_block_to_logical[block.cpu_block_id] + + self.gpu_slot_to_logical[gpu_slot] = logical_id + + block.location = BlockLocation.GPU + block.gpu_slot = gpu_slot + block.cpu_block_id = -1 + + # Notify policy + self.policy.on_block_prefetched(gpu_slot, self.current_step) + + return gpu_slot + + raise RuntimeError(f"Block {logical_id} is in invalid state") + + def can_allocate(self, seq: Sequence) -> bool: + """Check if we can allocate blocks for a new sequence.""" + return len(self.free_logical_ids) >= seq.num_blocks + + def allocate(self, seq: Sequence) -> None: + """ + Allocate logical blocks for prefill. + + New blocks are allocated on GPU when possible. If GPU is full and all + GPU blocks belong to this sequence (can't evict), remaining blocks + are allocated to CPU for chunked prefill. + """ + assert not seq.block_table, "Sequence already has blocks" + + h = -1 + cache_miss = False + + # Track blocks allocated for this sequence to protect them from eviction + allocated_for_seq: Set[int] = set() + + for i in range(seq.num_blocks): + token_ids = seq.block(i) + + # Hash for full blocks only + if len(token_ids) == self._block_size: + h = self.compute_hash(token_ids, h) + else: + h = -1 + + # Check prefix cache + cached_logical_id = self.hash_to_logical_id.get(h, -1) + if cached_logical_id != -1: + cached_block = self.logical_blocks[cached_logical_id] + if cached_block.token_ids == token_ids and cached_block.ref_count > 0: + # Cache hit + cached_block.ref_count += 1 + seq.num_cached_tokens += self._block_size + seq.block_table.append(cached_logical_id) + allocated_for_seq.add(cached_logical_id) + + # Ensure block is on GPU (protect already allocated blocks) + if cached_block.location == BlockLocation.CPU: + self._ensure_on_gpu(cached_logical_id, allocated_for_seq) + + continue + + cache_miss = True + + # Allocate new logical block + logical_id = self.free_logical_ids.popleft() + block = self.logical_blocks[logical_id] + block.ref_count = 1 + block.hash = h + block.token_ids = token_ids.copy() if len(token_ids) == self._block_size else [] + + # Try to allocate GPU slot + gpu_slot = self._try_allocate_gpu_slot(allocated_for_seq) + if gpu_slot is not None: + # Got GPU slot + block.location = BlockLocation.GPU + block.gpu_slot = gpu_slot + block.cpu_block_id = -1 + self.gpu_slot_to_logical[gpu_slot] = logical_id + else: + # GPU full and can't evict (all protected) - allocate to CPU + # This block will be written via chunked prefill + if not self.free_cpu_blocks: + raise RuntimeError( + f"Both GPU and CPU are full. 
Need {seq.num_blocks} blocks, " + f"GPU has {self.num_gpu_slots}, CPU has {self.num_cpu_blocks}" + ) + cpu_block_id = self.free_cpu_blocks.popleft() + block.location = BlockLocation.CPU + block.gpu_slot = -1 + block.cpu_block_id = cpu_block_id + self.cpu_block_to_logical[cpu_block_id] = logical_id + + allocated_for_seq.add(logical_id) + + # Update prefix cache + if h != -1: + self.hash_to_logical_id[h] = logical_id + + # Notify policy + self.policy.on_block_allocated(gpu_slot, self.current_step) + + seq.block_table.append(logical_id) + + def deallocate(self, seq: Sequence) -> None: + """Release all blocks for a sequence.""" + for logical_id in reversed(seq.block_table): + block = self.logical_blocks[logical_id] + block.ref_count -= 1 + + if block.ref_count == 0: + # Free physical block + if block.location == BlockLocation.GPU: + self.free_gpu_slots.append(block.gpu_slot) + del self.gpu_slot_to_logical[block.gpu_slot] + self.policy.on_block_deallocated(block.gpu_slot) + elif block.location == BlockLocation.CPU: + self.free_cpu_blocks.append(block.cpu_block_id) + del self.cpu_block_to_logical[block.cpu_block_id] + + # Free logical block + block.reset() + self.free_logical_ids.append(logical_id) + + # Remove from prefilled tracking + self.prefilled_blocks.discard(logical_id) + + seq.num_cached_tokens = 0 + seq.block_table.clear() + + def can_append(self, seq: Sequence) -> bool: + """Check if we can append a token.""" + need_new_block = (len(seq) % self._block_size == 1) + return len(self.free_logical_ids) >= int(need_new_block) + + def may_append(self, seq: Sequence) -> None: + """Handle potential new block allocation during decode.""" + block_table = seq.block_table + last_logical_id = block_table[-1] + last_block = self.logical_blocks[last_logical_id] + + seq_len = len(seq) + pos_in_block = seq_len % self._block_size + + if pos_in_block == 1: + # Need new block + assert last_block.hash != -1 + + logical_id = self.free_logical_ids.popleft() + block = self.logical_blocks[logical_id] + block.ref_count = 1 + block.hash = -1 + block.token_ids = [] + + # New decode blocks go to GPU + gpu_slot = self._allocate_gpu_slot() + block.location = BlockLocation.GPU + block.gpu_slot = gpu_slot + + self.gpu_slot_to_logical[gpu_slot] = logical_id + self.policy.on_block_allocated(gpu_slot, self.current_step) + + block_table.append(logical_id) + + elif pos_in_block == 0: + # Block is full, update hash for prefix cache + assert last_block.hash == -1 + token_ids = seq.block(seq.num_blocks - 1) + prefix_hash = ( + self.logical_blocks[block_table[-2]].hash + if len(block_table) > 1 else -1 + ) + h = self.compute_hash(token_ids, prefix_hash) + last_block.hash = h + last_block.token_ids = token_ids.copy() + self.hash_to_logical_id[h] = last_logical_id + + def prepare_for_attention( + self, + seqs: List[Sequence], + is_prefill: bool, + ) -> None: + """ + Prepare KV cache for attention computation. + + For prefill: async prefetch blocks from CPU to GPU. + For decode: update gather_indices for CUDA graph. 
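+
+        Illustrative ordering of a decode step (the model runner is
+        expected to drive these calls; shown here only as a sketch):
+
+        ```python
+        # outside the CUDA graph: refresh which CPU blocks map to which GPU slots
+        manager.prepare_for_attention(seqs, is_prefill=False)
+        # inside the captured graph, per layer: dst[i] = src[indices[i]]
+        manager.offload_engine.gathered_h2d_layer(layer_id)
+        # attention then consumes GPU slot ids
+        block_tables = manager.get_gpu_block_tables(seqs)
+        ```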
+ """ + self.current_step += 1 + + # Collect all needed logical blocks + needed_logical_ids: Set[int] = set() + for seq in seqs: + needed_logical_ids.update(seq.block_table) + + if is_prefill: + # Prefill: ensure all blocks on GPU (async prefetch) + # Pass needed_logical_ids as protected to prevent evicting blocks we need + for logical_id in needed_logical_ids: + block = self.logical_blocks[logical_id] + if block.location == BlockLocation.CPU: + self._ensure_on_gpu(logical_id, needed_logical_ids) + + # Wait for all prefetches to complete + self.offload_engine.wait_all_transfers() + + else: + # Decode: Check if we need chunked decode + cpu_blocks_count = sum( + 1 for lid in needed_logical_ids + if self.logical_blocks[lid].location == BlockLocation.CPU + ) + + if cpu_blocks_count > self.num_gpu_slots: + # Too many blocks on CPU - will use chunked decode + # Don't try to load all blocks now + return + + # Standard decode: prepare gather_indices for CUDA graph + # Identify blocks needing transfer + self.pending_gpu_loads.clear() + mappings_per_layer: List[List[Tuple[int, int]]] = [ + [] for _ in range(self.offload_engine.num_layers) + ] + + for logical_id in needed_logical_ids: + block = self.logical_blocks[logical_id] + if block.location == BlockLocation.CPU: + # Allocate GPU slot (protect needed blocks from eviction) + gpu_slot = self._allocate_gpu_slot(needed_logical_ids) + + # Record mapping for each layer + for layer_id in range(self.offload_engine.num_layers): + mappings_per_layer[layer_id].append( + (block.cpu_block_id, gpu_slot) + ) + + # Update block state + self.free_cpu_blocks.append(block.cpu_block_id) + del self.cpu_block_to_logical[block.cpu_block_id] + + self.gpu_slot_to_logical[gpu_slot] = logical_id + block.location = BlockLocation.GPU + block.gpu_slot = gpu_slot + block.cpu_block_id = -1 + + self.pending_gpu_loads.add(logical_id) + self.policy.on_block_prefetched(gpu_slot, self.current_step) + + elif block.location == BlockLocation.GPU: + self.policy.on_block_access(block.gpu_slot, self.current_step) + + # Update gather indices (outside graph) + self.offload_engine.update_gather_indices_all_layers(mappings_per_layer) + self.offload_engine.sync_indices() + + def needs_chunked_decode(self, seq: Sequence) -> bool: + """ + Check if sequence needs chunked decode. + + Returns True if there are blocks on CPU and total blocks exceed GPU capacity. + """ + cpu_blocks = 0 + for logical_id in seq.block_table: + block = self.logical_blocks[logical_id] + if block.location == BlockLocation.CPU: + cpu_blocks += 1 + return cpu_blocks > 0 and len(seq.block_table) > self.num_gpu_slots + + def load_all_kv_for_layer( + self, + seq: Sequence, + layer_id: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Load ALL KV for a sequence from both GPU and CPU for a layer. + + Used during chunked decode to compute full attention. 
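+
+        A layer consumes the result roughly like this (sketch; mirrors
+        ``Attention._chunked_decode_attention``):
+
+        ```python
+        k_all, v_all = manager.load_all_kv_for_layer(seq, layer_id)
+        # q: [batch, num_heads, head_dim] -> [batch, 1, num_heads, head_dim]
+        out, _ = flash_attn_with_lse(q.unsqueeze(1), k_all, v_all,
+                                     softmax_scale=scale, causal=False)
+        ```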
+ + Returns: + (k, v) tensors with shape [1, total_tokens, kv_heads, head_dim] + """ + k_chunks = [] + v_chunks = [] + + for logical_id in seq.block_table: + block = self.logical_blocks[logical_id] + + if block.location == BlockLocation.GPU: + # Get from GPU cache + k, v = self.offload_engine.get_layer_cache(layer_id) + # k, v shape: [num_gpu_blocks, block_size, kv_heads, head_dim] + k_block = k[block.gpu_slot] # [block_size, kv_heads, head_dim] + v_block = v[block.gpu_slot] + k_chunks.append(k_block) + v_chunks.append(v_block) + + elif block.location == BlockLocation.CPU: + # Get from CPU cache + k_block, v_block = self.offload_engine.get_cpu_block(layer_id, block.cpu_block_id) + # Already [block_size, kv_heads, head_dim] + k_chunks.append(k_block.to("cuda", non_blocking=True)) + v_chunks.append(v_block.to("cuda", non_blocking=True)) + + # Concatenate all chunks + k_all = torch.cat(k_chunks, dim=0) # [total_tokens, kv_heads, head_dim] + v_all = torch.cat(v_chunks, dim=0) + + # Add batch dimension + k_all = k_all.unsqueeze(0) # [1, total_tokens, kv_heads, head_dim] + v_all = v_all.unsqueeze(0) + + return k_all, v_all + + def get_gpu_block_tables( + self, + seqs: List[Sequence], + ) -> List[List[int]]: + """ + Get GPU slot tables for sequences. + + Returns GPU slot IDs, which may differ from logical block IDs. + """ + result = [] + for seq in seqs: + gpu_table = [] + for logical_id in seq.block_table: + block = self.logical_blocks[logical_id] + assert block.location == BlockLocation.GPU, ( + f"Block {logical_id} not on GPU (location={block.location})" + ) + gpu_table.append(block.gpu_slot) + result.append(gpu_table) + return result + + def post_attention_cleanup( + self, + seqs: List[Sequence], + is_prefill: bool, + ) -> None: + """ + Cleanup after attention. + + Clear pending loads and optionally proactive offload. + """ + self.pending_gpu_loads.clear() + + # ========== Chunked Prefill Support ========== + + def needs_chunked_prefill(self, seq: Sequence) -> bool: + """ + Check if sequence needs chunked prefill. + + Returns True if there are unprefilled blocks that are on CPU. + This indicates we need to process in chunks because not all blocks fit on GPU. + """ + for logical_id in seq.block_table: + if logical_id not in self.prefilled_blocks: + block = self.logical_blocks[logical_id] + if block.location == BlockLocation.CPU: + return True + return False + + def get_gpu_block_count(self, seq: Sequence) -> int: + """Get number of blocks currently on GPU for this sequence.""" + count = 0 + for logical_id in seq.block_table: + block = self.logical_blocks[logical_id] + if block.location == BlockLocation.GPU: + count += 1 + return count + + def get_prefill_chunk_info(self, seq: Sequence) -> Tuple[int, int, List[int]]: + """ + Get information for current prefill chunk. 
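+
+        Illustrative example: with block_table = [L0, L1, L2, L3] where L0
+        and L1 sit on GPU slots 5 and 2 and L2, L3 are still on CPU, this
+        returns (0, 2, [5, 2]).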
+ + Returns: + (start_block_idx, end_block_idx, gpu_block_ids) + - start_block_idx: First block index in this chunk + - end_block_idx: Last block index (exclusive) in this chunk + - gpu_block_ids: GPU slot IDs for blocks in this chunk + """ + start_idx = -1 + end_idx = -1 + gpu_block_ids = [] + + for i, logical_id in enumerate(seq.block_table): + block = self.logical_blocks[logical_id] + if block.location == BlockLocation.GPU: + if start_idx == -1: + start_idx = i + end_idx = i + 1 + gpu_block_ids.append(block.gpu_slot) + elif start_idx != -1: + # Found CPU block after GPU blocks - stop here + break + + if start_idx == -1: + return (0, 0, []) + + return (start_idx, end_idx, gpu_block_ids) + + def complete_prefill_chunk(self, seq: Sequence) -> bool: + """ + Complete a prefill chunk: mark blocks as prefilled, offload to CPU, load next chunk. + + Returns: + True if there are more chunks to process, False if done. + """ + # Find blocks currently on GPU that were just prefilled + gpu_blocks_to_offload = [] + for logical_id in seq.block_table: + block = self.logical_blocks[logical_id] + if block.location == BlockLocation.GPU and logical_id not in self.prefilled_blocks: + # Mark as prefilled + self.prefilled_blocks.add(logical_id) + gpu_blocks_to_offload.append(logical_id) + + # Offload prefilled GPU blocks to CPU + for logical_id in gpu_blocks_to_offload: + block = self.logical_blocks[logical_id] + if not self.free_cpu_blocks: + raise RuntimeError("No free CPU blocks for offload") + + cpu_block_id = self.free_cpu_blocks.popleft() + + # Async offload all layers + for layer_id in range(self.offload_engine.num_layers): + self.offload_engine.offload_block_async( + layer_id=layer_id, + gpu_block_id=block.gpu_slot, + cpu_block_id=cpu_block_id, + ) + + # Update mappings + self.free_gpu_slots.append(block.gpu_slot) + del self.gpu_slot_to_logical[block.gpu_slot] + self.cpu_block_to_logical[cpu_block_id] = logical_id + + block.location = BlockLocation.CPU + block.cpu_block_id = cpu_block_id + block.gpu_slot = -1 + + # Wait for offload to complete + self.offload_engine.wait_all_transfers() + + # Find next UNPREFILLED CPU blocks and bring them to GPU + cpu_blocks_to_load = [] + for logical_id in seq.block_table: + if logical_id in self.prefilled_blocks: + continue # Skip already prefilled + block = self.logical_blocks[logical_id] + if block.location == BlockLocation.CPU: + if len(cpu_blocks_to_load) >= self.num_gpu_slots: + break # GPU is full + cpu_blocks_to_load.append(logical_id) + + if not cpu_blocks_to_load: + return False # All blocks have been prefilled + + # Load unprefilled CPU blocks to GPU + for logical_id in cpu_blocks_to_load: + block = self.logical_blocks[logical_id] + gpu_slot = self.free_gpu_slots.popleft() + + # Note: We're NOT prefetching existing data - these blocks are being + # loaded for the first time, so we just need to assign GPU slots + # The model will write new KV cache data to these slots + + # Update mappings + self.free_cpu_blocks.append(block.cpu_block_id) + del self.cpu_block_to_logical[block.cpu_block_id] + self.gpu_slot_to_logical[gpu_slot] = logical_id + + block.location = BlockLocation.GPU + block.gpu_slot = gpu_slot + block.cpu_block_id = -1 + + return True # More chunks to process + + def get_gpu_block_tables_partial( + self, + seqs: List[Sequence], + ) -> List[Tuple[List[int], int, int]]: + """ + Get GPU block tables for chunked prefill. + + Returns list of (gpu_block_ids, start_block_idx, end_block_idx) per sequence. 
+ Only includes blocks that are currently on GPU AND haven't been prefilled yet. + """ + result = [] + for seq in seqs: + gpu_table = [] + start_idx = -1 + end_idx = -1 + + for i, logical_id in enumerate(seq.block_table): + # Skip already prefilled blocks + if logical_id in self.prefilled_blocks: + continue + + block = self.logical_blocks[logical_id] + if block.location == BlockLocation.GPU: + if start_idx == -1: + start_idx = i + end_idx = i + 1 + gpu_table.append(block.gpu_slot) + elif start_idx != -1: + # Stop at first non-GPU block after GPU blocks + break + + if start_idx == -1: + start_idx = 0 + end_idx = 0 + + result.append((gpu_table, start_idx, end_idx)) + return result + + def get_prefilled_cpu_blocks(self, seq: Sequence) -> List[int]: + """ + Get list of CPU block IDs for blocks that have been prefilled. + + Used for loading previous KV during chunked prefill. + + Returns: + List of CPU block IDs in sequence order + """ + cpu_blocks = [] + for logical_id in seq.block_table: + if logical_id in self.prefilled_blocks: + block = self.logical_blocks[logical_id] + if block.location == BlockLocation.CPU: + cpu_blocks.append(block.cpu_block_id) + return cpu_blocks + + def load_prev_kv_for_layer( + self, + seq: Sequence, + layer_id: int, + ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Load previous prefilled KV from CPU for a specific layer. + + This concatenates KV from all previously prefilled blocks for use + during chunked prefill attention. + + Args: + seq: Sequence to load KV for + layer_id: Layer index + + Returns: + (k, v) tensors with shape [1, total_prev_tokens, kv_heads, head_dim] + or (None, None) if no previous KV exists + """ + cpu_blocks = self.get_prefilled_cpu_blocks(seq) + if not cpu_blocks: + return None, None + + k_chunks = [] + v_chunks = [] + + for cpu_block_id in cpu_blocks: + k, v = self.offload_engine.get_cpu_block(layer_id, cpu_block_id) + # k, v shape: [block_size, kv_heads, head_dim] + k_chunks.append(k) + v_chunks.append(v) + + # Concatenate all chunks + k_prev = torch.cat(k_chunks, dim=0) # [total_prev_tokens, kv_heads, head_dim] + v_prev = torch.cat(v_chunks, dim=0) + + # Move to GPU and add batch dimension + k_prev = k_prev.to("cuda", non_blocking=True).unsqueeze(0) # [1, tokens, heads, dim] + v_prev = v_prev.to("cuda", non_blocking=True).unsqueeze(0) + + return k_prev, v_prev + + def get_chunk_start_position(self, seq: Sequence) -> int: + """ + Get the starting token position for the current chunk. + + This is the total number of tokens in previously prefilled blocks. + + Returns: + Token position offset for current chunk + """ + pos = 0 + for logical_id in seq.block_table: + if logical_id in self.prefilled_blocks: + # Full block's worth of tokens + pos += self._block_size + else: + break + return pos + + def __repr__(self) -> str: + return ( + f"HybridKVCacheManager(\n" + f" num_gpu_slots={self.num_gpu_slots},\n" + f" num_cpu_blocks={self.num_cpu_blocks},\n" + f" block_size={self._block_size},\n" + f" free_logical={len(self.free_logical_ids)},\n" + f" free_gpu={len(self.free_gpu_slots)},\n" + f" free_cpu={len(self.free_cpu_blocks)},\n" + f" policy={self.policy}\n" + f")" + ) diff --git a/nanovllm/kvcache/kernels.py b/nanovllm/kvcache/kernels.py new file mode 100644 index 0000000..db2e046 --- /dev/null +++ b/nanovllm/kvcache/kernels.py @@ -0,0 +1,190 @@ +""" +Triton kernels for CPU-GPU KV cache transfer. 
+ +These kernels are designed to be CUDA Graph compatible: +- All tensor addresses are fixed at graph capture time +- Only the content of index tensors changes between replays +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def gathered_copy_kernel( + src_ptr, # Source tensor base pointer (CPU pinned or GPU) + dst_ptr, # Destination tensor base pointer (GPU) + indices_ptr, # Gather indices [num_dst_blocks] + num_dst_blocks, # Number of destination blocks + block_numel: tl.constexpr, # Elements per block (block_size * kv_heads * head_dim) + BLOCK_SIZE: tl.constexpr = 1024, +): + """ + Gathered copy kernel: dst[i] = src[indices[i]] + + Each program instance handles one destination block. + The indices tensor specifies which source block to copy from. + + This kernel is CUDA Graph compatible because: + - src_ptr, dst_ptr, indices_ptr addresses are fixed + - Only indices content changes between graph replays + + Args: + src_ptr: Base pointer to source blocks [num_src_blocks, block_numel] + dst_ptr: Base pointer to destination blocks [num_dst_blocks, block_numel] + indices_ptr: Gather indices [num_dst_blocks], each value is a source block index + num_dst_blocks: Number of destination blocks to copy + block_numel: Number of elements per block + BLOCK_SIZE: Triton block size for parallelization + """ + dst_block_idx = tl.program_id(0) + + # Skip if out of range + if dst_block_idx >= num_dst_blocks: + return + + # Load source block index from indices tensor + src_block_idx = tl.load(indices_ptr + dst_block_idx) + + # Skip if index is -1 (invalid/no-op marker) + if src_block_idx < 0: + return + + # Calculate base offsets + src_base = src_block_idx * block_numel + dst_base = dst_block_idx * block_numel + + # Copy block data in chunks of BLOCK_SIZE + for start in range(0, block_numel, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < block_numel + + # Load from source and store to destination + data = tl.load(src_ptr + src_base + offsets, mask=mask) + tl.store(dst_ptr + dst_base + offsets, data, mask=mask) + + +@triton.jit +def gathered_copy_kv_kernel( + k_src_ptr, # K cache source [num_src_blocks, block_size, kv_heads, head_dim] + v_src_ptr, # V cache source + k_dst_ptr, # K cache destination + v_dst_ptr, # V cache destination + indices_ptr, # Gather indices [num_dst_blocks] + num_dst_blocks, # Number of destination blocks + block_numel: tl.constexpr, # Elements per block + BLOCK_SIZE: tl.constexpr = 1024, +): + """ + Gathered copy for both K and V caches simultaneously. 
+ + More efficient than calling gathered_copy_kernel twice because: + - Single kernel launch overhead + - Better memory access patterns when K and V are accessed together + """ + dst_block_idx = tl.program_id(0) + + if dst_block_idx >= num_dst_blocks: + return + + src_block_idx = tl.load(indices_ptr + dst_block_idx) + + if src_block_idx < 0: + return + + src_base = src_block_idx * block_numel + dst_base = dst_block_idx * block_numel + + for start in range(0, block_numel, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < block_numel + + # Copy K cache + k_data = tl.load(k_src_ptr + src_base + offsets, mask=mask) + tl.store(k_dst_ptr + dst_base + offsets, k_data, mask=mask) + + # Copy V cache + v_data = tl.load(v_src_ptr + src_base + offsets, mask=mask) + tl.store(v_dst_ptr + dst_base + offsets, v_data, mask=mask) + + +def gathered_copy( + src: torch.Tensor, + dst: torch.Tensor, + indices: torch.Tensor, +) -> None: + """ + Perform gathered copy: dst[i] = src[indices[i]] + + Args: + src: Source tensor [num_src_blocks, ...] + dst: Destination tensor [num_dst_blocks, ...] + indices: Index tensor [num_dst_blocks], dtype=int64 + -1 means skip (no-op) + + Note: + - src can be on CPU (pinned memory) or GPU + - dst must be on GPU + - indices must be on GPU + - All shapes after first dimension must match + """ + assert dst.is_cuda, "Destination must be on GPU" + assert indices.is_cuda, "Indices must be on GPU" + assert src.shape[1:] == dst.shape[1:], "Shape mismatch after first dimension" + + num_dst_blocks = dst.shape[0] + block_numel = dst[0].numel() + + # Flatten for kernel + src_flat = src.view(src.shape[0], -1) + dst_flat = dst.view(dst.shape[0], -1) + + grid = (num_dst_blocks,) + gathered_copy_kernel[grid]( + src_flat, + dst_flat, + indices, + num_dst_blocks, + block_numel=block_numel, + ) + + +def gathered_copy_kv( + k_src: torch.Tensor, + v_src: torch.Tensor, + k_dst: torch.Tensor, + v_dst: torch.Tensor, + indices: torch.Tensor, +) -> None: + """ + Perform gathered copy for both K and V caches. + + Args: + k_src, v_src: Source K/V caches [num_src_blocks, block_size, kv_heads, head_dim] + k_dst, v_dst: Destination K/V caches [num_dst_blocks, block_size, kv_heads, head_dim] + indices: Index tensor [num_dst_blocks], dtype=int64 + """ + assert k_dst.is_cuda and v_dst.is_cuda, "Destinations must be on GPU" + assert indices.is_cuda, "Indices must be on GPU" + assert k_src.shape[1:] == k_dst.shape[1:], "K shape mismatch" + assert v_src.shape[1:] == v_dst.shape[1:], "V shape mismatch" + + num_dst_blocks = k_dst.shape[0] + block_numel = k_dst[0].numel() + + k_src_flat = k_src.view(k_src.shape[0], -1) + v_src_flat = v_src.view(v_src.shape[0], -1) + k_dst_flat = k_dst.view(k_dst.shape[0], -1) + v_dst_flat = v_dst.view(v_dst.shape[0], -1) + + grid = (num_dst_blocks,) + gathered_copy_kv_kernel[grid]( + k_src_flat, + v_src_flat, + k_dst_flat, + v_dst_flat, + indices, + num_dst_blocks, + block_numel=block_numel, + ) diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py new file mode 100644 index 0000000..188b74f --- /dev/null +++ b/nanovllm/kvcache/offload_engine.py @@ -0,0 +1,400 @@ +""" +High-performance CPU-GPU KV cache transfer engine. + +Key design principles for CUDA Graph compatibility: +1. All tensor addresses are fixed at initialization +2. Only index tensor contents change between graph replays +3. 
Supports both async transfer (for prefill) and graph-based transfer (for decode) +""" + +import torch +from torch import Tensor +from typing import Dict, List, Tuple, Optional +from dataclasses import dataclass + +from nanovllm.kvcache.kernels import gathered_copy_kv + + +@dataclass +class TransferEvent: + """Tracks a pending async transfer.""" + event: torch.cuda.Event + layer_id: int + src_block_id: int + dst_block_id: int + direction: str # "h2d" or "d2h" + + +class OffloadEngine: + """ + High-performance CPU-GPU async transfer engine for KV cache offloading. + + Memory layout: + - GPU cache: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim] + - CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] (pinned) + - Gather indices: [num_layers, num_gpu_blocks] (fixed address, variable content) + + CUDA Graph compatibility: + - gathered_h2d_layer() can be captured into CUDA graphs + - update_gather_indices() is called outside graphs to prepare indices + - All tensor addresses remain fixed across graph replays + """ + + def __init__( + self, + num_layers: int, + num_gpu_blocks: int, + num_cpu_blocks: int, + block_size: int, + num_kv_heads: int, + head_dim: int, + dtype: torch.dtype = torch.float16, + num_streams: int = 4, + ): + self.num_layers = num_layers + self.num_gpu_blocks = num_gpu_blocks + self.num_cpu_blocks = num_cpu_blocks + self.block_size = block_size + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.dtype = dtype + self.kv_dim = num_kv_heads * head_dim + self.block_numel = block_size * self.kv_dim + + # ========== Fixed-address GPU KV cache ========== + # Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim] + self.k_cache_gpu = torch.empty( + num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim, + dtype=dtype, device="cuda" + ) + self.v_cache_gpu = torch.empty( + num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim, + dtype=dtype, device="cuda" + ) + + # ========== Fixed-address CPU KV cache (pinned memory) ========== + self.k_cache_cpu = torch.empty( + num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim, + dtype=dtype, device="cpu", pin_memory=True + ) + self.v_cache_cpu = torch.empty( + num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim, + dtype=dtype, device="cpu", pin_memory=True + ) + + # ========== Fixed-address gather indices (content is variable) ========== + # gather_indices[layer][i] = CPU block id to copy to GPU slot i + # -1 means no-op (skip this slot) + self.gather_indices_cpu = torch.empty( + num_layers, num_gpu_blocks, + dtype=torch.int64, device="cpu", pin_memory=True + ) + self.gather_indices_cpu.fill_(-1) + self.gather_indices_gpu = torch.full( + (num_layers, num_gpu_blocks), -1, + dtype=torch.int64, device="cuda" + ) + + # ========== Transfer streams for async operations ========== + self.transfer_streams = [torch.cuda.Stream() for _ in range(num_streams)] + self.compute_stream = torch.cuda.current_stream() + self._stream_idx = 0 + + # ========== Event tracking for async transfers ========== + self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {} + + def _get_next_stream(self) -> torch.cuda.Stream: + """Round-robin stream selection for parallel transfers.""" + stream = self.transfer_streams[self._stream_idx] + self._stream_idx = (self._stream_idx + 1) % len(self.transfer_streams) + return stream + + # ========== CUDA Graph compatible methods ========== + + def gathered_h2d_layer(self, layer_id: int) -> None: + """ + Execute gathered H2D copy 
for a single layer. + + This method is CUDA Graph compatible - can be captured into a graph. + Before calling, update_gather_indices() must be called to set up + which CPU blocks to copy to which GPU slots. + + Args: + layer_id: Layer index to transfer + """ + gathered_copy_kv( + k_src=self.k_cache_cpu[layer_id], + v_src=self.v_cache_cpu[layer_id], + k_dst=self.k_cache_gpu[layer_id], + v_dst=self.v_cache_gpu[layer_id], + indices=self.gather_indices_gpu[layer_id], + ) + + def gathered_h2d_all_layers(self) -> None: + """ + Execute gathered H2D copy for all layers. + + CUDA Graph compatible - can be captured into a single graph. + """ + for layer_id in range(self.num_layers): + self.gathered_h2d_layer(layer_id) + + def update_gather_indices( + self, + layer_id: int, + mappings: List[Tuple[int, int]], + ) -> None: + """ + Update gather indices for a layer (call OUTSIDE CUDA graph). + + Args: + layer_id: Layer index + mappings: List of (cpu_block_id, gpu_slot) tuples + Only these slots will be updated; others keep their values + """ + for cpu_block_id, gpu_slot in mappings: + self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id + + # Async copy to GPU + self.gather_indices_gpu[layer_id].copy_( + self.gather_indices_cpu[layer_id], + non_blocking=True + ) + + def update_gather_indices_all_layers( + self, + mappings_per_layer: List[List[Tuple[int, int]]], + ) -> None: + """ + Update gather indices for all layers. + + Args: + mappings_per_layer: mappings_per_layer[layer_id] = [(cpu_block_id, gpu_slot), ...] + """ + for layer_id, mappings in enumerate(mappings_per_layer): + for cpu_block_id, gpu_slot in mappings: + self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id + + # Batch copy all layers + self.gather_indices_gpu.copy_(self.gather_indices_cpu, non_blocking=True) + + def clear_gather_indices(self, layer_id: Optional[int] = None) -> None: + """ + Clear gather indices (set all to -1, meaning no-op). + + Args: + layer_id: If provided, clear only this layer; otherwise clear all + """ + if layer_id is not None: + self.gather_indices_cpu[layer_id].fill_(-1) + self.gather_indices_gpu[layer_id].fill_(-1) + else: + self.gather_indices_cpu.fill_(-1) + self.gather_indices_gpu.fill_(-1) + + # ========== Async transfer methods (for prefill, outside CUDA graph) ========== + + def prefetch_block_async( + self, + layer_id: int, + cpu_block_id: int, + gpu_block_id: int, + ) -> torch.cuda.Event: + """ + Async prefetch a single block from CPU to GPU. + + For use in prefill phase where CUDA graphs are not used. + + Args: + layer_id: Layer index + cpu_block_id: Source block in CPU cache + gpu_block_id: Destination slot in GPU cache + + Returns: + CUDA event that signals completion + """ + stream = self._get_next_stream() + event = torch.cuda.Event() + + with torch.cuda.stream(stream): + # K cache + self.k_cache_gpu[layer_id, gpu_block_id].copy_( + self.k_cache_cpu[layer_id, cpu_block_id], + non_blocking=True + ) + # V cache + self.v_cache_gpu[layer_id, gpu_block_id].copy_( + self.v_cache_cpu[layer_id, cpu_block_id], + non_blocking=True + ) + event.record() + + self.pending_events[(layer_id, gpu_block_id)] = event + return event + + def prefetch_blocks_batch_async( + self, + transfers: List[Tuple[int, int, int]], # [(layer_id, cpu_block_id, gpu_block_id), ...] + ) -> List[torch.cuda.Event]: + """ + Batch async prefetch multiple blocks. 
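+
+        Usage sketch (illustrative; ``cpu_id`` and ``gpu_slot`` stand for a
+        mapping chosen by the cache manager):
+
+        ```python
+        transfers = [(layer, cpu_id, gpu_slot)
+                     for layer in range(engine.num_layers)]
+        events = engine.prefetch_blocks_batch_async(transfers)
+        for e in events:
+            e.synchronize()        # or engine.wait_all_transfers()
+        ```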
+ + Args: + transfers: List of (layer_id, cpu_block_id, gpu_block_id) tuples + + Returns: + List of CUDA events for each transfer + """ + events = [] + for layer_id, cpu_block_id, gpu_block_id in transfers: + event = self.prefetch_block_async(layer_id, cpu_block_id, gpu_block_id) + events.append(event) + return events + + def offload_block_async( + self, + layer_id: int, + gpu_block_id: int, + cpu_block_id: int, + ) -> torch.cuda.Event: + """ + Async offload a block from GPU to CPU. + + Args: + layer_id: Layer index + gpu_block_id: Source slot in GPU cache + cpu_block_id: Destination block in CPU cache + + Returns: + CUDA event that signals completion + """ + stream = self._get_next_stream() + event = torch.cuda.Event() + + with torch.cuda.stream(stream): + # Wait for any compute using this block + stream.wait_stream(self.compute_stream) + + # K cache + self.k_cache_cpu[layer_id, cpu_block_id].copy_( + self.k_cache_gpu[layer_id, gpu_block_id], + non_blocking=True + ) + # V cache + self.v_cache_cpu[layer_id, cpu_block_id].copy_( + self.v_cache_gpu[layer_id, gpu_block_id], + non_blocking=True + ) + event.record() + + return event + + def offload_blocks_batch_async( + self, + transfers: List[Tuple[int, int, int]], # [(layer_id, gpu_block_id, cpu_block_id), ...] + ) -> List[torch.cuda.Event]: + """ + Batch async offload multiple blocks. + + Args: + transfers: List of (layer_id, gpu_block_id, cpu_block_id) tuples + + Returns: + List of CUDA events + """ + events = [] + for layer_id, gpu_block_id, cpu_block_id in transfers: + event = self.offload_block_async(layer_id, gpu_block_id, cpu_block_id) + events.append(event) + return events + + # ========== Synchronization methods ========== + + def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None: + """Wait for a specific block's transfer to complete.""" + key = (layer_id, gpu_block_id) + if key in self.pending_events: + self.pending_events[key].synchronize() + del self.pending_events[key] + + def wait_all_transfers(self) -> None: + """Wait for all pending transfers to complete.""" + for stream in self.transfer_streams: + stream.synchronize() + self.pending_events.clear() + + def sync_indices(self) -> None: + """Synchronize to ensure all index updates are complete.""" + torch.cuda.current_stream().synchronize() + + # ========== Cache access methods ========== + + def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]: + """ + Get GPU K/V cache tensors for a specific layer. + + Returns: + (k_cache, v_cache) tensors for the layer + Shape: [num_gpu_blocks, block_size, kv_heads, head_dim] + """ + return self.k_cache_gpu[layer_id], self.v_cache_gpu[layer_id] + + def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]: + """ + Get full GPU K/V cache tensors. + + Returns: + (k_cache, v_cache) tensors + Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim] + """ + return self.k_cache_gpu, self.v_cache_gpu + + def get_cpu_block( + self, + layer_id: int, + cpu_block_id: int, + ) -> Tuple[Tensor, Tensor]: + """ + Get a specific CPU block's K/V cache. 
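+
+        The returned tensors are views into the pinned CPU pool; callers
+        that need them on GPU copy explicitly, e.g. (sketch):
+
+        ```python
+        k_blk, v_blk = engine.get_cpu_block(layer_id, cpu_block_id)
+        k_gpu = k_blk.to("cuda", non_blocking=True)
+        v_gpu = v_blk.to("cuda", non_blocking=True)
+        ```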
+ + Returns: + (k_cache, v_cache) for the block + Shape: [block_size, kv_heads, head_dim] + """ + return ( + self.k_cache_cpu[layer_id, cpu_block_id], + self.v_cache_cpu[layer_id, cpu_block_id], + ) + + # ========== Memory info ========== + + def gpu_memory_bytes(self) -> int: + """Total GPU memory used by KV caches.""" + return ( + self.k_cache_gpu.numel() * self.k_cache_gpu.element_size() + + self.v_cache_gpu.numel() * self.v_cache_gpu.element_size() + + self.gather_indices_gpu.numel() * self.gather_indices_gpu.element_size() + ) + + def cpu_memory_bytes(self) -> int: + """Total CPU memory used by KV caches.""" + return ( + self.k_cache_cpu.numel() * self.k_cache_cpu.element_size() + + self.v_cache_cpu.numel() * self.v_cache_cpu.element_size() + + self.gather_indices_cpu.numel() * self.gather_indices_cpu.element_size() + ) + + def __repr__(self) -> str: + return ( + f"OffloadEngine(\n" + f" num_layers={self.num_layers},\n" + f" num_gpu_blocks={self.num_gpu_blocks},\n" + f" num_cpu_blocks={self.num_cpu_blocks},\n" + f" block_size={self.block_size},\n" + f" kv_heads={self.num_kv_heads},\n" + f" head_dim={self.head_dim},\n" + f" dtype={self.dtype},\n" + f" gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n" + f" cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n" + f")" + ) \ No newline at end of file diff --git a/nanovllm/kvcache/policies/__init__.py b/nanovllm/kvcache/policies/__init__.py new file mode 100644 index 0000000..bdab8b4 --- /dev/null +++ b/nanovllm/kvcache/policies/__init__.py @@ -0,0 +1,51 @@ +""" +Eviction policy plugins for KV cache offloading. + +Users can create custom policies by subclassing EvictionPolicy +and specifying the full class path in config.offload_policy. +""" + +from nanovllm.kvcache.policies.base_policy import EvictionPolicy +from nanovllm.kvcache.policies.lru_policy import LRUPolicy +from nanovllm.kvcache.policies.fifo_policy import FIFOPolicy + +# Built-in policy registry +BUILTIN_POLICIES = { + "lru": LRUPolicy, + "fifo": FIFOPolicy, +} + + +def get_policy(policy_name: str) -> EvictionPolicy: + """ + Get an eviction policy instance by name or class path. + + Args: + policy_name: Either a built-in name ("lru", "fifo") or + a full class path ("mymodule.MyPolicy") + + Returns: + EvictionPolicy instance + """ + # Check built-in policies first + if policy_name.lower() in BUILTIN_POLICIES: + return BUILTIN_POLICIES[policy_name.lower()]() + + # Try to import custom policy + try: + module_path, class_name = policy_name.rsplit(".", 1) + import importlib + module = importlib.import_module(module_path) + policy_class = getattr(module, class_name) + if not issubclass(policy_class, EvictionPolicy): + raise TypeError(f"{policy_name} is not a subclass of EvictionPolicy") + return policy_class() + except (ValueError, ImportError, AttributeError) as e: + raise ValueError( + f"Unknown policy '{policy_name}'. " + f"Available built-in policies: {list(BUILTIN_POLICIES.keys())}. " + f"For custom policies, use full class path: 'mymodule.MyPolicy'" + ) from e + + +__all__ = ["EvictionPolicy", "LRUPolicy", "FIFOPolicy", "get_policy", "BUILTIN_POLICIES"] \ No newline at end of file diff --git a/nanovllm/kvcache/policies/base_policy.py b/nanovllm/kvcache/policies/base_policy.py new file mode 100644 index 0000000..0482189 --- /dev/null +++ b/nanovllm/kvcache/policies/base_policy.py @@ -0,0 +1,156 @@ +""" +Base class for eviction policies. + +Users can implement custom policies by subclassing EvictionPolicy +and overriding the abstract methods. 
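+
+Policies are looked up by name or dotted class path, e.g. (sketch, where
+``mypolicies.MyCustomPolicy`` is a hypothetical user module):
+
+```python
+from nanovllm.kvcache.policies import get_policy
+
+policy = get_policy("lru")                        # built-in
+policy = get_policy("mypolicies.MyCustomPolicy")  # custom subclass
+```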
+""" + +from abc import ABC, abstractmethod +from typing import Set, Optional + + +class EvictionPolicy(ABC): + """ + Abstract base class for KV cache eviction policies. + + An eviction policy determines which GPU blocks to evict to CPU + when GPU memory is full and new blocks need to be allocated. + + Lifecycle: + 1. on_block_allocated() - called when a new block is allocated + 2. on_block_access() - called each time a block is accessed (e.g., in attention) + 3. select_victim() - called when a block needs to be evicted + 4. on_block_evicted() - called after a block is evicted + + Example custom policy: + ```python + class MyCustomPolicy(EvictionPolicy): + def __init__(self): + self.priorities = {} + + def on_block_allocated(self, block_id: int, step: int): + self.priorities[block_id] = step + + def on_block_access(self, block_id: int, step: int): + # Custom access tracking + pass + + def select_victim(self, candidates: Set[int]) -> int: + # Return block with lowest priority + return min(candidates, key=lambda b: self.priorities.get(b, 0)) + + def on_block_evicted(self, block_id: int): + self.priorities.pop(block_id, None) + ``` + """ + + @abstractmethod + def on_block_allocated(self, block_id: int, step: int) -> None: + """ + Called when a new block is allocated on GPU. + + Args: + block_id: The GPU block ID that was allocated + step: Current inference step (monotonically increasing) + """ + pass + + @abstractmethod + def on_block_access(self, block_id: int, step: int) -> None: + """ + Called when a block is accessed during attention computation. + + Args: + block_id: The GPU block ID being accessed + step: Current inference step + """ + pass + + @abstractmethod + def select_victim(self, candidates: Set[int]) -> int: + """ + Select a block to evict from the candidate set. + + This is called when GPU memory is full and a new block + needs to be allocated. The returned block will be evicted + to CPU. + + Args: + candidates: Set of GPU block IDs that can be evicted + (blocks not currently being used) + + Returns: + Block ID to evict + + Raises: + ValueError: If candidates is empty + """ + pass + + @abstractmethod + def on_block_evicted(self, block_id: int) -> None: + """ + Called after a block is evicted from GPU to CPU. + + Args: + block_id: The GPU block ID that was evicted + """ + pass + + def on_block_prefetched(self, block_id: int, step: int) -> None: + """ + Called when a block is prefetched from CPU back to GPU. + + Default implementation calls on_block_allocated(). + Override for custom behavior. + + Args: + block_id: The GPU block ID that was prefetched to + step: Current inference step + """ + self.on_block_allocated(block_id, step) + + def on_block_deallocated(self, block_id: int) -> None: + """ + Called when a block is fully deallocated (sequence finished). + + Default implementation calls on_block_evicted(). + Override for custom behavior. + + Args: + block_id: The GPU block ID being deallocated + """ + self.on_block_evicted(block_id) + + def reset(self) -> None: + """ + Reset policy state. + + Called when the inference engine is reset. + Default implementation does nothing. + """ + pass + + def get_eviction_order(self, candidates: Set[int], count: int) -> list: + """ + Get multiple blocks to evict in order of priority. + + Default implementation calls select_victim() repeatedly. + Override for more efficient batch selection. 
+ + Args: + candidates: Set of candidate block IDs + count: Number of blocks to evict + + Returns: + List of block IDs to evict, in order + """ + result = [] + remaining = set(candidates) + for _ in range(min(count, len(remaining))): + if not remaining: + break + victim = self.select_victim(remaining) + result.append(victim) + remaining.remove(victim) + return result \ No newline at end of file diff --git a/nanovllm/kvcache/policies/fifo_policy.py b/nanovllm/kvcache/policies/fifo_policy.py new file mode 100644 index 0000000..5b63c91 --- /dev/null +++ b/nanovllm/kvcache/policies/fifo_policy.py @@ -0,0 +1,101 @@ +""" +FIFO (First In, First Out) eviction policy. + +Evicts the block that was allocated earliest. +Simple policy that ignores access patterns. +""" + +from collections import OrderedDict +from typing import Set + +from nanovllm.kvcache.policies.base_policy import EvictionPolicy + + +class FIFOPolicy(EvictionPolicy): + """ + First In, First Out (FIFO) eviction policy. + + Evicts blocks in the order they were allocated, + regardless of access patterns. + + Properties: + - O(1) operations for all methods + - Simple and predictable behavior + - Good for streaming workloads where older data + is naturally less relevant + - Does not adapt to access patterns (unlike LRU) + """ + + def __init__(self): + # OrderedDict maintains insertion order + # Key: block_id, Value: allocation_step + # Oldest (first allocated) is at the front + self.allocation_order: OrderedDict[int, int] = OrderedDict() + + def on_block_allocated(self, block_id: int, step: int) -> None: + """Record allocation order (does not change on access).""" + if block_id not in self.allocation_order: + self.allocation_order[block_id] = step + + def on_block_access(self, block_id: int, step: int) -> None: + """ + FIFO ignores access patterns. + + This is the key difference from LRU - we don't + update the position based on access. + """ + pass # Intentionally empty + + def select_victim(self, candidates: Set[int]) -> int: + """ + Select the earliest allocated block from candidates. + """ + if not candidates: + raise ValueError("Cannot select victim from empty candidate set") + + # Iterate from oldest (front) to newest (back) + for block_id in self.allocation_order: + if block_id in candidates: + return block_id + + # Fallback: return any candidate + return next(iter(candidates)) + + def on_block_evicted(self, block_id: int) -> None: + """Remove block from tracking.""" + self.allocation_order.pop(block_id, None) + + def on_block_prefetched(self, block_id: int, step: int) -> None: + """ + When prefetched, treat as new allocation. + + This moves the block to the end of the queue, + giving it more time before eviction. + """ + # Remove old entry if exists + self.allocation_order.pop(block_id, None) + # Add as new allocation + self.allocation_order[block_id] = step + + def on_block_deallocated(self, block_id: int) -> None: + """Remove block from tracking.""" + self.allocation_order.pop(block_id, None) + + def reset(self) -> None: + """Clear all tracking data.""" + self.allocation_order.clear() + + def get_eviction_order(self, candidates: Set[int], count: int) -> list: + """ + Get multiple blocks to evict in FIFO order. 
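+
+        Illustrative example: if blocks were allocated in the order 3, 1, 7
+        and all are candidates, get_eviction_order({1, 3, 7}, 2) returns
+        [3, 1] regardless of how often each block was accessed since.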
+ """ + result = [] + for block_id in self.allocation_order: + if block_id in candidates: + result.append(block_id) + if len(result) >= count: + break + return result + + def __repr__(self) -> str: + return f"FIFOPolicy(tracked_blocks={len(self.allocation_order)})" \ No newline at end of file diff --git a/nanovllm/kvcache/policies/lru_policy.py b/nanovllm/kvcache/policies/lru_policy.py new file mode 100644 index 0000000..0776191 --- /dev/null +++ b/nanovllm/kvcache/policies/lru_policy.py @@ -0,0 +1,93 @@ +""" +LRU (Least Recently Used) eviction policy. + +Evicts the block that was accessed least recently. +This is the default and recommended policy for most use cases. +""" + +from collections import OrderedDict +from typing import Set + +from nanovllm.kvcache.policies.base_policy import EvictionPolicy + + +class LRUPolicy(EvictionPolicy): + """ + Least Recently Used (LRU) eviction policy. + + Maintains an ordered dictionary of block access times. + When eviction is needed, selects the block that was + accessed least recently. + + Properties: + - O(1) access tracking + - O(n) victim selection in worst case, but typically fast + due to OrderedDict iteration order + - Good for workloads with temporal locality + """ + + def __init__(self): + # OrderedDict maintains insertion/update order + # Key: block_id, Value: last_access_step + # Oldest (least recently used) is at the front + self.access_order: OrderedDict[int, int] = OrderedDict() + + def on_block_allocated(self, block_id: int, step: int) -> None: + """Record allocation as an access.""" + # Move to end (most recently used) + self.access_order[block_id] = step + self.access_order.move_to_end(block_id) + + def on_block_access(self, block_id: int, step: int) -> None: + """Update access time and move to end.""" + if block_id in self.access_order: + self.access_order[block_id] = step + self.access_order.move_to_end(block_id) + + def select_victim(self, candidates: Set[int]) -> int: + """ + Select the least recently used block from candidates. + + Iterates from oldest to newest in access order, + returns the first one that's in the candidate set. + """ + if not candidates: + raise ValueError("Cannot select victim from empty candidate set") + + # Iterate from oldest (front) to newest (back) + for block_id in self.access_order: + if block_id in candidates: + return block_id + + # Fallback: return any candidate (shouldn't happen normally) + return next(iter(candidates)) + + def on_block_evicted(self, block_id: int) -> None: + """Remove block from tracking.""" + self.access_order.pop(block_id, None) + + def on_block_deallocated(self, block_id: int) -> None: + """Remove block from tracking.""" + self.access_order.pop(block_id, None) + + def reset(self) -> None: + """Clear all tracking data.""" + self.access_order.clear() + + def get_eviction_order(self, candidates: Set[int], count: int) -> list: + """ + Efficiently get multiple blocks to evict in LRU order. + + Optimized for batch eviction - iterates through access_order + once instead of calling select_victim() multiple times. 
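+
+        Illustrative example: after allocating blocks 3, 1, 7 and then
+        accessing block 3, get_eviction_order({1, 3, 7}, 2) returns [1, 7],
+        because the access moved block 3 to the most-recently-used end.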
+ """ + result = [] + for block_id in self.access_order: + if block_id in candidates: + result.append(block_id) + if len(result) >= count: + break + return result + + def __repr__(self) -> str: + return f"LRUPolicy(tracked_blocks={len(self.access_order)})" \ No newline at end of file diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index f5046b3..39bdc38 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -55,21 +55,164 @@ class Attention(nn.Module): self.scale = scale self.num_kv_heads = num_kv_heads self.k_cache = self.v_cache = torch.tensor([]) + # Layer ID set by model_runner after model creation + self.layer_id: int = -1 def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): context = get_context() k_cache, v_cache = self.k_cache, self.v_cache if k_cache.numel() and v_cache.numel(): store_kvcache(k, v, k_cache, v_cache, context.slot_mapping) + if context.is_prefill: - if context.block_tables is not None: # prefix cache + if context.is_chunked_prefill: + # Chunked prefill: merge attention from previous KV + o = self._chunked_prefill_attention(q, k, v, context) + elif context.block_tables is not None: # prefix cache k, v = k_cache, v_cache - o = flash_attn_varlen_func(q, k, v, - max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, - max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k, - softmax_scale=self.scale, causal=True, block_table=context.block_tables) + o = flash_attn_varlen_func(q, k, v, + max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, + max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k, + softmax_scale=self.scale, causal=True, block_table=context.block_tables) + else: + o = flash_attn_varlen_func(q, k, v, + max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, + max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k, + softmax_scale=self.scale, causal=True, block_table=context.block_tables) else: # decode - o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, - cache_seqlens=context.context_lens, block_table=context.block_tables, - softmax_scale=self.scale, causal=True) + if context.is_chunked_prefill: + # Chunked decode: need to load all KV from CPU+GPU + o = self._chunked_decode_attention(q, k, v, context) + else: + o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, + cache_seqlens=context.context_lens, block_table=context.block_tables, + softmax_scale=self.scale, causal=True) return o + + def _chunked_prefill_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + context, + ) -> torch.Tensor: + """ + Compute attention with chunked KV from CPU cache. + + For chunked prefill: + 1. Load previous KV from CPU for this layer + 2. Compute attention against previous KV (no causal mask) + 3. Compute attention against current chunk's KV (causal) + 4. 
Merge results using online softmax + """ + from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs + + # q, k, v shape: [total_tokens, num_heads, head_dim] + total_tokens = q.shape[0] + + # Reshape for flash attention: [batch, seq, heads, dim] + q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim] + k_batched = k.unsqueeze(0) + v_batched = v.unsqueeze(0) + + accumulated_o = None + accumulated_lse = None + + # Load previous KV from CPU for this layer + if context.offload_engine is not None and self.layer_id >= 0: + # Get the kvcache_manager from context + kvcache_manager = context.offload_engine + + # For each sequence in the chunk, load previous KV + # Currently assuming single sequence + if hasattr(context, 'chunked_seq') and context.chunked_seq is not None: + prev_k, prev_v = kvcache_manager.load_prev_kv_for_layer( + context.chunked_seq, + self.layer_id, + ) + + if prev_k is not None and prev_v is not None: + # Compute attention against previous KV (no causal mask) + prev_o, prev_lse = flash_attn_with_lse( + q_batched, + prev_k, + prev_v, + softmax_scale=self.scale, + causal=False, # No causal mask for previous context + ) + accumulated_o = prev_o + accumulated_lse = prev_lse + + # Compute attention against current chunk's KV (with causal mask) + current_o, current_lse = flash_attn_with_lse( + q_batched, + k_batched, + v_batched, + softmax_scale=self.scale, + causal=True, # Causal mask for current chunk + ) + + # Merge with accumulated + if accumulated_o is None: + final_o = current_o + else: + final_o, _ = merge_attention_outputs( + accumulated_o, accumulated_lse, + current_o, current_lse, + ) + + # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim] + return final_o.squeeze(0) + + def _chunked_decode_attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + context, + ) -> torch.Tensor: + """ + Compute decode attention with KV spread across CPU and GPU. + + For decode with chunked KV: + 1. Load all KV for this layer from CPU+GPU + 2. Compute attention (1 query token vs all KV) + 3. 
Return output + """ + from nanovllm.kvcache.chunked_attention import flash_attn_with_lse + + # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence) + # We need to attend to ALL previous tokens + + # Load all KV for this layer + if context.offload_engine is not None and self.layer_id >= 0: + kvcache_manager = context.offload_engine + + if hasattr(context, 'chunked_seq') and context.chunked_seq is not None: + # Load all KV from both GPU and CPU for this layer + k_all, v_all = kvcache_manager.load_all_kv_for_layer( + context.chunked_seq, + self.layer_id, + ) + + if k_all is not None and v_all is not None: + # q shape: [batch_size, num_heads, head_dim] + # Need: [batch, seqlen, heads, dim] + # Insert seqlen dimension at position 1 + q_batched = q.unsqueeze(1) # [batch, 1, heads, dim] + + # k_all, v_all shape: [1, total_kv_tokens, kv_heads, head_dim] + # Compute attention (no causal mask for decode - we want all KV) + out, _ = flash_attn_with_lse( + q_batched, + k_all, + v_all, + softmax_scale=self.scale, + causal=False, # No causal mask for decode + ) + + # Output shape: [batch, 1, heads, dim] -> [batch, heads, dim] + return out.squeeze(1) + + # Fallback: shouldn't reach here + raise RuntimeError("Chunked decode attention failed: no KV available") diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py index 2281888..b6b09a4 100644 --- a/nanovllm/utils/context.py +++ b/nanovllm/utils/context.py @@ -1,4 +1,5 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field +from typing import Optional, List, Tuple, Any import torch @@ -13,14 +14,60 @@ class Context: context_lens: torch.Tensor | None = None block_tables: torch.Tensor | None = None + # Chunked prefill support + is_chunked_prefill: bool = False + # Previous KV chunks info: List of (start_pos, end_pos) for blocks on CPU + prev_kv_ranges: List[Tuple[int, int]] = field(default_factory=list) + # Current chunk's position offset (for causal mask) + chunk_offset: int = 0 + # Reference to kvcache manager for loading previous KV (HybridKVCacheManager) + offload_engine: Any = None + # Current layer's previous K/V chunks (loaded from CPU) + # Set by model_runner before each layer's forward + prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list) + # Current sequence being processed (for chunked prefill to load KV) + chunked_seq: Any = None + + _CONTEXT = Context() + def get_context(): return _CONTEXT -def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, context_lens=None, block_tables=None): + +def set_context( + is_prefill, + cu_seqlens_q=None, + cu_seqlens_k=None, + max_seqlen_q=0, + max_seqlen_k=0, + slot_mapping=None, + context_lens=None, + block_tables=None, + is_chunked_prefill=False, + prev_kv_ranges=None, + chunk_offset=0, + offload_engine=None, + chunked_seq=None, +): global _CONTEXT - _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables) + _CONTEXT = Context( + is_prefill=is_prefill, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + slot_mapping=slot_mapping, + context_lens=context_lens, + block_tables=block_tables, + is_chunked_prefill=is_chunked_prefill, + prev_kv_ranges=prev_kv_ranges or [], + chunk_offset=chunk_offset, + offload_engine=offload_engine, + chunked_seq=chunked_seq, + ) + def reset_context(): global _CONTEXT diff --git 
a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..aa26d43 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test suite for nano-vllm KV cache offload.""" diff --git a/tests/test_kernels.py b/tests/test_kernels.py new file mode 100644 index 0000000..af2dc59 --- /dev/null +++ b/tests/test_kernels.py @@ -0,0 +1,169 @@ +"""Tests for Triton gathered copy kernels.""" + +import pytest +import torch + +from nanovllm.kvcache.kernels import gathered_copy, gathered_copy_kv + + +class TestGatheredCopy: + """Tests for gathered copy kernel.""" + + @pytest.fixture + def setup_tensors(self): + """Create test tensors.""" + torch.cuda.manual_seed(42) + num_src_blocks = 16 + num_dst_blocks = 8 + block_size = 256 + kv_dim = 64 + + src = torch.randn(num_src_blocks, block_size, kv_dim, + dtype=torch.float16, device="cuda") + dst = torch.zeros(num_dst_blocks, block_size, kv_dim, + dtype=torch.float16, device="cuda") + + # Indices: dst[i] = src[indices[i]] + indices = torch.randint(0, num_src_blocks, (num_dst_blocks,), + dtype=torch.int64, device="cuda") + + return src, dst, indices + + def test_basic_copy(self, setup_tensors): + """Test basic gathered copy.""" + src, dst, indices = setup_tensors + + gathered_copy(src, dst, indices) + + # Verify copy + for i in range(len(indices)): + src_idx = indices[i].item() + assert torch.allclose(dst[i], src[src_idx]), f"Mismatch at index {i}" + + def test_skip_negative_indices(self, setup_tensors): + """Test that negative indices are skipped.""" + src, dst, indices = setup_tensors + + # Set some indices to -1 + indices[2] = -1 + indices[5] = -1 + + # Fill dst with a known value + dst.fill_(999.0) + + gathered_copy(src, dst, indices) + + # Skipped slots should be unchanged + assert (dst[2] == 999.0).all() + assert (dst[5] == 999.0).all() + + # Non-skipped slots should be copied + for i in [0, 1, 3, 4, 6, 7]: + src_idx = indices[i].item() + assert torch.allclose(dst[i], src[src_idx]) + + def test_single_block(self): + """Test copying a single block.""" + src = torch.randn(4, 256, 64, dtype=torch.float16, device="cuda") + dst = torch.zeros(1, 256, 64, dtype=torch.float16, device="cuda") + indices = torch.tensor([2], dtype=torch.int64, device="cuda") + + gathered_copy(src, dst, indices) + + assert torch.allclose(dst[0], src[2]) + + +class TestGatheredCopyKV: + """Tests for gathered K/V cache copy kernel.""" + + @pytest.fixture + def setup_kv_tensors(self): + """Create K/V test tensors.""" + torch.cuda.manual_seed(42) + num_src_blocks = 16 + num_dst_blocks = 8 + block_size = 256 + num_kv_heads = 4 + head_dim = 64 + + k_src = torch.randn(num_src_blocks, block_size, num_kv_heads, head_dim, + dtype=torch.float16, device="cuda") + v_src = torch.randn(num_src_blocks, block_size, num_kv_heads, head_dim, + dtype=torch.float16, device="cuda") + k_dst = torch.zeros(num_dst_blocks, block_size, num_kv_heads, head_dim, + dtype=torch.float16, device="cuda") + v_dst = torch.zeros(num_dst_blocks, block_size, num_kv_heads, head_dim, + dtype=torch.float16, device="cuda") + + indices = torch.randint(0, num_src_blocks, (num_dst_blocks,), + dtype=torch.int64, device="cuda") + + return k_src, v_src, k_dst, v_dst, indices + + def test_kv_copy(self, setup_kv_tensors): + """Test K/V gathered copy.""" + k_src, v_src, k_dst, v_dst, indices = setup_kv_tensors + + gathered_copy_kv(k_src, v_src, k_dst, v_dst, indices) + + # Verify copy + for i in range(len(indices)): + src_idx = indices[i].item() + assert torch.allclose(k_dst[i], k_src[src_idx]), f"K mismatch at {i}" + assert 
torch.allclose(v_dst[i], v_src[src_idx]), f"V mismatch at {i}" + + def test_kv_skip_negative(self, setup_kv_tensors): + """Test that negative indices are skipped for K/V.""" + k_src, v_src, k_dst, v_dst, indices = setup_kv_tensors + + indices[0] = -1 + k_dst.fill_(999.0) + v_dst.fill_(999.0) + + gathered_copy_kv(k_src, v_src, k_dst, v_dst, indices) + + assert (k_dst[0] == 999.0).all() + assert (v_dst[0] == 999.0).all() + + +class TestPerformance: + """Performance benchmarks for gathered copy.""" + + @pytest.mark.parametrize("num_blocks", [8, 32, 128]) + def test_throughput(self, num_blocks): + """Benchmark copy throughput.""" + block_size = 256 + kv_dim = 64 + + src = torch.randn(num_blocks * 2, block_size, kv_dim, + dtype=torch.float16, device="cuda") + dst = torch.zeros(num_blocks, block_size, kv_dim, + dtype=torch.float16, device="cuda") + indices = torch.arange(num_blocks, dtype=torch.int64, device="cuda") + + # Warmup + for _ in range(10): + gathered_copy(src, dst, indices) + torch.cuda.synchronize() + + # Benchmark + import time + start = time.perf_counter() + num_iters = 100 + for _ in range(num_iters): + gathered_copy(src, dst, indices) + torch.cuda.synchronize() + elapsed = time.perf_counter() - start + + bytes_copied = num_blocks * block_size * kv_dim * 2 * num_iters # fp16 + bandwidth_gbps = bytes_copied / elapsed / 1e9 + + print(f"\n{num_blocks} blocks: {bandwidth_gbps:.2f} GB/s") + + # Should achieve reasonable bandwidth (lower threshold for small blocks due to kernel launch overhead) + min_bandwidth = 5 if num_blocks <= 16 else 10 + assert bandwidth_gbps > min_bandwidth, f"Bandwidth too low: {bandwidth_gbps} GB/s" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_kvcache_manager.py b/tests/test_kvcache_manager.py new file mode 100644 index 0000000..e3798b3 --- /dev/null +++ b/tests/test_kvcache_manager.py @@ -0,0 +1,175 @@ +"""Tests for KV cache managers.""" + +import pytest +import torch + +from nanovllm.engine.sequence import Sequence +from nanovllm.kvcache.gpu_manager import GPUOnlyManager + + +class MockSequence: + """Mock sequence for testing block allocation.""" + + def __init__(self, token_ids: list[int], block_size: int = 256): + self._token_ids = token_ids + self._block_size = block_size + self.block_table: list[int] = [] + self.num_cached_tokens = 0 + + def __len__(self): + return len(self._token_ids) + + @property + def num_blocks(self) -> int: + return (len(self) + self._block_size - 1) // self._block_size + + def block(self, i: int) -> list[int]: + start = i * self._block_size + end = min((i + 1) * self._block_size, len(self)) + return self._token_ids[start:end] + + +class TestGPUOnlyManager: + """Tests for GPU-only KV cache manager.""" + + @pytest.fixture + def manager(self): + """Create a small manager for testing.""" + return GPUOnlyManager(num_blocks=16, block_size=256) + + def test_initialization(self, manager): + """Test manager initialization.""" + assert manager.block_size == 256 + assert manager.num_free_blocks == 16 + assert len(manager.blocks) == 16 + + def test_allocate_cache(self, manager): + """Test cache allocation.""" + manager.allocate_cache( + num_layers=4, + num_kv_heads=8, + head_dim=64, + dtype=torch.float16, + ) + + assert manager.kv_cache is not None + assert manager.kv_cache.shape == (2, 4, 16, 256, 8, 64) + assert manager.kv_cache.device.type == "cuda" + + def test_get_layer_cache(self, manager): + """Test getting layer cache.""" + manager.allocate_cache( + num_layers=4, + num_kv_heads=8, + 
head_dim=64, + dtype=torch.float16, + ) + + k_cache, v_cache = manager.get_layer_cache(0) + assert k_cache.shape == (16, 256, 8, 64) + assert v_cache.shape == (16, 256, 8, 64) + + def test_can_allocate(self, manager): + """Test allocation check.""" + seq = MockSequence([0] * 300) # Needs 2 blocks + assert manager.can_allocate(seq) + + # Fill up all blocks with unique tokens to avoid prefix caching + for i in range(8): + # Each sequence has unique tokens to prevent prefix cache hits + s = MockSequence([i * 1000 + j for j in range(300)]) + manager.allocate(s) + + # Now should not be able to allocate + new_seq = MockSequence([9999] * 300) + assert not manager.can_allocate(new_seq) + + def test_allocate_and_deallocate(self, manager): + """Test block allocation and deallocation.""" + seq = MockSequence([0] * 600) # Needs 3 blocks + initial_free = manager.num_free_blocks + + manager.allocate(seq) + assert len(seq.block_table) == 3 + assert manager.num_free_blocks == initial_free - 3 + + manager.deallocate(seq) + assert len(seq.block_table) == 0 + assert manager.num_free_blocks == initial_free + + def test_can_append(self, manager): + """Test append check.""" + seq = MockSequence([0] * 256) # Exactly 1 block + manager.allocate(seq) + + # Can append without new block (still in same block) + seq._token_ids = [0] * 257 + assert manager.can_append(seq) + + def test_prepare_for_attention_noop(self, manager): + """Test that prepare_for_attention is a no-op for GPU-only.""" + seq = MockSequence([0] * 100) + manager.allocate(seq) + + # Should not raise + manager.prepare_for_attention([seq], is_prefill=True) + manager.prepare_for_attention([seq], is_prefill=False) + + def test_get_gpu_block_tables(self, manager): + """Test getting GPU block tables.""" + seq1 = MockSequence([0] * 300) + seq2 = MockSequence([0] * 600) + + manager.allocate(seq1) + manager.allocate(seq2) + + tables = manager.get_gpu_block_tables([seq1, seq2]) + + assert len(tables) == 2 + assert tables[0] == list(seq1.block_table) + assert tables[1] == list(seq2.block_table) + + +class TestGPUOnlyManagerPrefixCaching: + """Tests for prefix caching in GPU-only manager.""" + + @pytest.fixture + def manager(self): + """Create manager for testing.""" + return GPUOnlyManager(num_blocks=32, block_size=256) + + def test_prefix_cache_hit(self, manager): + """Test that identical prefixes are cached.""" + # Create two sequences with same prefix + tokens = list(range(512)) # 2 full blocks + seq1 = MockSequence(tokens) + seq2 = MockSequence(tokens) + + manager.allocate(seq1) + initial_free = manager.num_free_blocks + + manager.allocate(seq2) + + # Second sequence should reuse cached blocks + assert seq2.num_cached_tokens >= 256 # At least first block cached + # Should use fewer new blocks + assert manager.num_free_blocks >= initial_free - 2 + + def test_prefix_cache_different_suffix(self, manager): + """Test cache with same prefix but different suffix.""" + prefix = list(range(256)) # 1 full block + + seq1 = MockSequence(prefix + [1000, 1001]) + seq2 = MockSequence(prefix + [2000, 2001]) + + manager.allocate(seq1) + manager.allocate(seq2) + + # First block should be shared + assert seq1.block_table[0] == seq2.block_table[0] + # Second block should be different + assert seq1.block_table[1] != seq2.block_table[1] + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_offload_engine.py b/tests/test_offload_engine.py new file mode 100644 index 0000000..8613ee5 --- /dev/null +++ b/tests/test_offload_engine.py @@ -0,0 +1,196 @@ 
+"""Tests for CPU-GPU offload engine.""" + +import pytest +import torch + +from nanovllm.kvcache.offload_engine import OffloadEngine + + +class TestOffloadEngine: + """Tests for OffloadEngine.""" + + @pytest.fixture + def engine(self): + """Create a small engine for testing.""" + return OffloadEngine( + num_layers=2, + num_gpu_blocks=4, + num_cpu_blocks=8, + block_size=256, + num_kv_heads=4, + head_dim=64, + dtype=torch.float16, + num_streams=2, + ) + + def test_initialization(self, engine): + """Test engine initialization.""" + # Check GPU cache shape + assert engine.k_cache_gpu.shape == (2, 4, 256, 4, 64) + assert engine.v_cache_gpu.shape == (2, 4, 256, 4, 64) + + # Check CPU cache shape + assert engine.k_cache_cpu.shape == (2, 8, 256, 4, 64) + assert engine.v_cache_cpu.shape == (2, 8, 256, 4, 64) + + # Check pinned memory + assert engine.k_cache_cpu.is_pinned() + assert engine.v_cache_cpu.is_pinned() + + # Check gather indices + assert engine.gather_indices_cpu.shape == (2, 4) + assert engine.gather_indices_gpu.shape == (2, 4) + + def test_get_layer_cache(self, engine): + """Test getting layer cache.""" + k, v = engine.get_layer_cache(0) + assert k.shape == (4, 256, 4, 64) + assert v.shape == (4, 256, 4, 64) + assert k.device.type == "cuda" + assert v.device.type == "cuda" + + def test_prefetch_and_offload(self, engine): + """Test async prefetch and offload.""" + # Write some data to CPU block 0 + engine.k_cache_cpu[0, 0].fill_(1.0) + engine.v_cache_cpu[0, 0].fill_(2.0) + + # Prefetch to GPU block 2 + event = engine.prefetch_block_async( + layer_id=0, + cpu_block_id=0, + gpu_block_id=2, + ) + event.synchronize() + + # Verify data was copied (move GPU to CPU for comparison) + assert torch.allclose(engine.k_cache_gpu[0, 2].cpu(), engine.k_cache_cpu[0, 0]) + assert torch.allclose(engine.v_cache_gpu[0, 2].cpu(), engine.v_cache_cpu[0, 0]) + + # Modify GPU data + engine.k_cache_gpu[0, 2].fill_(3.0) + engine.v_cache_gpu[0, 2].fill_(4.0) + + # Offload to CPU block 5 + event = engine.offload_block_async( + layer_id=0, + gpu_block_id=2, + cpu_block_id=5, + ) + event.synchronize() + + # Verify data was copied + assert torch.allclose(engine.k_cache_cpu[0, 5], engine.k_cache_gpu[0, 2].cpu()) + assert torch.allclose(engine.v_cache_cpu[0, 5], engine.v_cache_gpu[0, 2].cpu()) + + def test_update_gather_indices(self, engine): + """Test updating gather indices.""" + # Manually set CPU data + for i in range(8): + engine.k_cache_cpu[0, i].fill_(float(i)) + engine.v_cache_cpu[0, i].fill_(float(i + 100)) + + # Update indices for layer 0: (cpu_block_id, gpu_slot) + mappings = [(2, 0), (5, 1), (1, 2), (7, 3)] + engine.update_gather_indices(layer_id=0, mappings=mappings) + torch.cuda.synchronize() + + # Verify indices were set + expected = torch.tensor([2, 5, 1, 7], dtype=torch.int64) + assert torch.equal(engine.gather_indices_cpu[0], expected) + + def test_gathered_h2d_layer(self, engine): + """Test gathered H2D copy for a layer.""" + # Set up CPU data with known values + for i in range(8): + engine.k_cache_cpu[0, i].fill_(float(i)) + engine.v_cache_cpu[0, i].fill_(float(i + 100)) + + # Set gather indices: (cpu_block_id, gpu_slot) + # GPU slot 0 gets CPU block 3, GPU slot 1 gets CPU block 0, etc. 
+ mappings = [(3, 0), (0, 1), (7, 2), (2, 3)] + engine.update_gather_indices(layer_id=0, mappings=mappings) + torch.cuda.synchronize() + + # Execute gathered H2D + engine.gathered_h2d_layer(layer_id=0) + torch.cuda.synchronize() + + # Verify: GPU slot 0 should have CPU block 3's data + assert torch.allclose(engine.k_cache_gpu[0, 0], + torch.full_like(engine.k_cache_gpu[0, 0], 3.0)) + # GPU slot 1 should have CPU block 0's data + assert torch.allclose(engine.k_cache_gpu[0, 1], + torch.full_like(engine.k_cache_gpu[0, 1], 0.0)) + # GPU slot 2 should have CPU block 7's data + assert torch.allclose(engine.k_cache_gpu[0, 2], + torch.full_like(engine.k_cache_gpu[0, 2], 7.0)) + # GPU slot 3 should have CPU block 2's data + assert torch.allclose(engine.k_cache_gpu[0, 3], + torch.full_like(engine.k_cache_gpu[0, 3], 2.0)) + + def test_multi_layer_independence(self, engine): + """Test that layers are independent.""" + # Set different data for each layer + engine.k_cache_cpu[0, 0].fill_(1.0) + engine.k_cache_cpu[1, 0].fill_(2.0) + + # Prefetch layer 0 + event = engine.prefetch_block_async(0, 0, 0) + event.synchronize() + + # Verify only layer 0 was affected + assert torch.allclose(engine.k_cache_gpu[0, 0], + torch.full_like(engine.k_cache_gpu[0, 0], 1.0)) + # Layer 1 should be zeros (initial state) + assert not torch.allclose(engine.k_cache_gpu[1, 0], + torch.full_like(engine.k_cache_gpu[1, 0], 2.0)) + + +class TestOffloadEngineFixedAddresses: + """Tests verifying fixed address property for CUDA Graph compatibility.""" + + @pytest.fixture + def engine(self): + """Create engine for address tests.""" + return OffloadEngine( + num_layers=2, + num_gpu_blocks=4, + num_cpu_blocks=8, + block_size=256, + num_kv_heads=4, + head_dim=64, + dtype=torch.float16, + num_streams=2, + ) + + def test_gpu_cache_address_fixed(self, engine): + """Verify GPU cache addresses don't change.""" + k_ptr_before = engine.k_cache_gpu.data_ptr() + v_ptr_before = engine.v_cache_gpu.data_ptr() + + # Perform some operations - mappings is List[(cpu_block_id, gpu_slot)] + mappings = [(0, 0), (1, 1), (2, 2), (3, 3)] + engine.update_gather_indices(0, mappings) + engine.gathered_h2d_layer(0) + torch.cuda.synchronize() + + # Addresses should be the same + assert engine.k_cache_gpu.data_ptr() == k_ptr_before + assert engine.v_cache_gpu.data_ptr() == v_ptr_before + + def test_gather_indices_gpu_address_fixed(self, engine): + """Verify gather indices GPU tensor address doesn't change.""" + ptr_before = engine.gather_indices_gpu.data_ptr() + + # Update indices multiple times - mappings is List[(cpu_block_id, gpu_slot)] + mappings = [(0, 0), (1, 1), (2, 2), (3, 3)] + for _ in range(10): + engine.update_gather_indices(0, mappings) + torch.cuda.synchronize() + + assert engine.gather_indices_gpu.data_ptr() == ptr_before + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_policies.py b/tests/test_policies.py new file mode 100644 index 0000000..d241148 --- /dev/null +++ b/tests/test_policies.py @@ -0,0 +1,167 @@ +"""Tests for eviction policies.""" + +import pytest +from nanovllm.kvcache.policies.lru_policy import LRUPolicy +from nanovllm.kvcache.policies.fifo_policy import FIFOPolicy +from nanovllm.kvcache.policies import get_policy + + +class TestLRUPolicy: + """Tests for LRU eviction policy.""" + + def test_basic_eviction(self): + """Test that LRU evicts least recently used block.""" + policy = LRUPolicy() + + # Allocate blocks 0, 1, 2 in order + policy.on_block_allocated(0, step=1) + policy.on_block_allocated(1, 
step=2) + policy.on_block_allocated(2, step=3) + + # Access block 0 (makes it most recently used) + policy.on_block_access(0, step=4) + + # Should evict block 1 (least recently used) + candidates = {0, 1, 2} + victim = policy.select_victim(candidates) + assert victim == 1, f"Expected block 1, got {victim}" + + def test_access_updates_order(self): + """Test that access updates LRU order.""" + policy = LRUPolicy() + + policy.on_block_allocated(0, step=1) + policy.on_block_allocated(1, step=2) + policy.on_block_allocated(2, step=3) + + # Access all in reverse order + policy.on_block_access(2, step=4) + policy.on_block_access(1, step=5) + policy.on_block_access(0, step=6) + + # Block 2 is now LRU (accessed earliest after allocation update) + candidates = {0, 1, 2} + victim = policy.select_victim(candidates) + assert victim == 2, f"Expected block 2, got {victim}" + + def test_eviction_removes_from_tracking(self): + """Test that evicted blocks are removed from tracking.""" + policy = LRUPolicy() + + policy.on_block_allocated(0, step=1) + policy.on_block_allocated(1, step=2) + + policy.on_block_evicted(0) + + # Only block 1 should be a candidate + candidates = {0, 1} + victim = policy.select_victim(candidates) + assert victim == 1, "Should select block 1 since 0 was evicted" + + def test_batch_eviction_order(self): + """Test get_eviction_order returns blocks in LRU order.""" + policy = LRUPolicy() + + for i in range(5): + policy.on_block_allocated(i, step=i) + + # Access blocks 2 and 4 + policy.on_block_access(2, step=10) + policy.on_block_access(4, step=11) + + candidates = {0, 1, 2, 3, 4} + order = policy.get_eviction_order(candidates, count=3) + + # Should be 0, 1, 3 (in that order, skipping 2 and 4 until needed) + assert order == [0, 1, 3], f"Expected [0, 1, 3], got {order}" + + +class TestFIFOPolicy: + """Tests for FIFO eviction policy.""" + + def test_basic_eviction(self): + """Test that FIFO evicts oldest allocated block.""" + policy = FIFOPolicy() + + policy.on_block_allocated(0, step=1) + policy.on_block_allocated(1, step=2) + policy.on_block_allocated(2, step=3) + + # Access doesn't change FIFO order + policy.on_block_access(0, step=4) + + candidates = {0, 1, 2} + victim = policy.select_victim(candidates) + assert victim == 0, f"Expected block 0 (oldest), got {victim}" + + def test_access_does_not_update_order(self): + """Test that FIFO ignores access patterns.""" + policy = FIFOPolicy() + + policy.on_block_allocated(0, step=1) + policy.on_block_allocated(1, step=2) + policy.on_block_allocated(2, step=3) + + # Multiple accesses to block 0 + for i in range(10): + policy.on_block_access(0, step=10 + i) + + # Block 0 should still be evicted first (FIFO order) + candidates = {0, 1, 2} + victim = policy.select_victim(candidates) + assert victim == 0, f"Expected block 0, got {victim}" + + def test_prefetch_resets_order(self): + """Test that prefetch moves block to end of queue.""" + policy = FIFOPolicy() + + policy.on_block_allocated(0, step=1) + policy.on_block_allocated(1, step=2) + policy.on_block_allocated(2, step=3) + + # Prefetch block 0 (moves to end) + policy.on_block_prefetched(0, step=4) + + candidates = {0, 1, 2} + victim = policy.select_victim(candidates) + assert victim == 1, f"Expected block 1 (now oldest), got {victim}" + + def test_batch_eviction_order(self): + """Test get_eviction_order returns blocks in FIFO order.""" + policy = FIFOPolicy() + + for i in range(5): + policy.on_block_allocated(i, step=i) + + candidates = {0, 1, 2, 3, 4} + order = 
policy.get_eviction_order(candidates, count=3) + + assert order == [0, 1, 2], f"Expected [0, 1, 2], got {order}" + + +class TestGetPolicy: + """Tests for policy factory function.""" + + def test_get_lru(self): + """Test getting LRU policy by name.""" + policy = get_policy("lru") + assert isinstance(policy, LRUPolicy) + + def test_get_fifo(self): + """Test getting FIFO policy by name.""" + policy = get_policy("fifo") + assert isinstance(policy, FIFOPolicy) + + def test_get_by_class_path(self): + """Test getting policy by full class path.""" + policy = get_policy("nanovllm.kvcache.policies.lru_policy.LRUPolicy") + assert isinstance(policy, LRUPolicy) + + def test_invalid_policy_name(self): + """Test that invalid policy name raises error.""" + with pytest.raises((ValueError, ImportError)): + get_policy("invalid_policy") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])
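
Note (illustration, not part of the patch): the chunked-prefill path in
nanovllm/layers/attention.py computes attention against the offloaded
"previous" KV and against the current chunk separately, then merges the two
partial results using their log-sum-exp (LSE) values, i.e. the standard
online-softmax combination. The sketch below is a minimal pure-PyTorch model
of that merge; the helper names and shapes are assumptions for the sketch,
not the API of the patch's flash_attn_with_lse / merge_attention_outputs, and
the causal masking that the real code applies to the current chunk is omitted
for brevity.

    import torch

    def attn_with_lse(q, k, v, scale):
        # q: [Tq, H, D], k/v: [Tk, H, D]; returns o: [Tq, H, D], lse: [H, Tq]
        scores = torch.einsum("qhd,khd->hqk", q, k) * scale
        lse = torch.logsumexp(scores, dim=-1)
        o = torch.einsum("hqk,khd->qhd", scores.softmax(dim=-1), v)
        return o, lse

    def merge(o1, lse1, o2, lse2):
        # Online-softmax merge: weight each partial output by its share of the
        # combined softmax normalizer.
        lse = torch.logaddexp(lse1, lse2)
        w1 = (lse1 - lse).exp().permute(1, 0).unsqueeze(-1)  # [Tq, H, 1]
        w2 = (lse2 - lse).exp().permute(1, 0).unsqueeze(-1)
        return w1 * o1 + w2 * o2, lse

    torch.manual_seed(0)
    H, D, Tq, T1, T2 = 4, 64, 8, 32, 48
    scale = D ** -0.5
    q = torch.randn(Tq, H, D)
    k1, v1 = torch.randn(T1, H, D), torch.randn(T1, H, D)  # "previous" (offloaded) KV
    k2, v2 = torch.randn(T2, H, D), torch.randn(T2, H, D)  # current chunk's KV

    o1, lse1 = attn_with_lse(q, k1, v1, scale)
    o2, lse2 = attn_with_lse(q, k2, v2, scale)
    full, _ = attn_with_lse(q, torch.cat([k1, k2]), torch.cat([v1, v2]), scale)
    chunked, _ = merge(o1, lse1, o2, lse2)
    assert torch.allclose(chunked, full, atol=1e-5, rtol=1e-4)

The same identity is why _chunked_decode_attention runs with causal=False: a
single decode query attends to all previously cached tokens, so once
load_all_kv_for_layer has concatenated the GPU- and CPU-resident blocks there
is nothing left to mask.

A second illustration, for the Triton kernels exercised in
tests/test_kernels.py: a pure-PyTorch reference of the tested contract (a
sketch, not the Triton implementation) would be:

    import torch

    def gathered_copy_ref(src: torch.Tensor, dst: torch.Tensor, indices: torch.Tensor) -> None:
        # dst[i] = src[indices[i]] for indices[i] >= 0; a negative index leaves
        # dst[i] untouched (the skip behavior checked in test_skip_negative_indices).
        for i, idx in enumerate(indices.tolist()):
            if idx >= 0:
                dst[i].copy_(src[idx])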