[feat] Added chunked prefill and KV cache offload mechanism.
bench.py | 10
@@ -44,16 +44,16 @@ def main():
|
||||
print("Prefill Benchmark")
|
||||
print("=" * 60)
|
||||
bench_prefill(llm, num_seqs=1, input_len=1024)
|
||||
bench_prefill(llm, num_seqs=1, input_len=2048)
|
||||
bench_prefill(llm, num_seqs=1, input_len=4095)
|
||||
bench_prefill(llm, num_seqs=16, input_len=1024)
|
||||
bench_prefill(llm, num_seqs=64, input_len=1024)
|
||||
# bench_prefill(llm, num_seqs=1, input_len=2048)
|
||||
# bench_prefill(llm, num_seqs=1, input_len=4095)
|
||||
# bench_prefill(llm, num_seqs=16, input_len=1024)
|
||||
# bench_prefill(llm, num_seqs=64, input_len=1024)
|
||||
|
||||
print("=" * 60)
|
||||
print("Decode Benchmark")
|
||||
print("=" * 60)
|
||||
bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=1024)
|
||||
bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024)
|
||||
# bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
bench_offload.py | 64 (new file)
@@ -0,0 +1,64 @@
|
||||
import os
|
||||
import time
|
||||
from random import randint, seed
|
||||
from nanovllm import LLM, SamplingParams
|
||||
|
||||
|
||||
def bench_decode(llm, num_seqs, max_input_len, max_output_len):
|
||||
"""Benchmark decode performance"""
|
||||
seed(0)
|
||||
prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
|
||||
sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
|
||||
|
||||
t = time.time()
|
||||
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
||||
t = time.time() - t
|
||||
total_output_tokens = sum(sp.max_tokens for sp in sampling_params)
|
||||
throughput = total_output_tokens / t
|
||||
print(f"[Decode] Output: {total_output_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
||||
|
||||
|
||||
def bench_prefill(llm, num_seqs, input_len):
|
||||
"""Benchmark prefill performance"""
|
||||
seed(0)
|
||||
prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
|
||||
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
|
||||
|
||||
t = time.time()
|
||||
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
||||
t = time.time() - t
|
||||
total_input_tokens = num_seqs * input_len
|
||||
throughput = total_input_tokens / t
|
||||
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
||||
|
||||
|
||||
def main():
|
||||
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
||||
llm = LLM(
|
||||
path,
|
||||
enforce_eager=True,
|
||||
max_model_len=128 * 1024,
|
||||
max_num_batched_tokens=128 * 1024,
|
||||
enable_cpu_offload=True,
|
||||
cpu_memory_gb=32.0,
|
||||
)
|
||||
|
||||
# Warmup
|
||||
llm.generate(["Benchmark: "], SamplingParams())
|
||||
|
||||
print("=" * 60)
|
||||
print("Prefill Benchmark (CPU Offload)")
|
||||
print("=" * 60)
|
||||
bench_prefill(llm, num_seqs=1, input_len=64*1024)
|
||||
# bench_prefill(llm, num_seqs=1, input_len=16384)
|
||||
# bench_prefill(llm, num_seqs=1, input_len=32000)
|
||||
|
||||
print("=" * 60)
|
||||
print("Decode Benchmark (CPU Offload)")
|
||||
print("=" * 60)
|
||||
bench_decode(llm, num_seqs=1, max_input_len=64*1024, max_output_len=256)
|
||||
# bench_decode(llm, num_seqs=1, max_input_len=16384, max_output_len=256)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -47,16 +47,16 @@ def main():
|
||||
print("Prefill Benchmark")
|
||||
print("=" * 60)
|
||||
bench_prefill(llm, num_seqs=1, input_len=1024)
|
||||
bench_prefill(llm, num_seqs=1, input_len=2048)
|
||||
bench_prefill(llm, num_seqs=1, input_len=4095)
|
||||
bench_prefill(llm, num_seqs=16, input_len=1024)
|
||||
bench_prefill(llm, num_seqs=64, input_len=1024)
|
||||
# bench_prefill(llm, num_seqs=1, input_len=2048)
|
||||
# bench_prefill(llm, num_seqs=1, input_len=4095)
|
||||
# bench_prefill(llm, num_seqs=16, input_len=1024)
|
||||
# bench_prefill(llm, num_seqs=64, input_len=1024)
|
||||
|
||||
print("=" * 60)
|
||||
print("Decode Benchmark")
|
||||
print("=" * 60)
|
||||
bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=1024)
|
||||
bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024)
|
||||
# bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -17,6 +17,16 @@ class Config:
|
||||
kvcache_block_size: int = 256
|
||||
num_kvcache_blocks: int = -1
|
||||
|
||||
# CPU Offload configuration
|
||||
enable_cpu_offload: bool = False
|
||||
cpu_memory_gb: float = 16.0 # CPU memory limit for KV cache
|
||||
offload_policy: str = "lru" # "lru", "fifo", or full class path
|
||||
num_transfer_streams: int = 4 # Number of CUDA streams for async transfers
|
||||
|
||||
# Computed fields for offload (set in __post_init__ or by ModelRunner)
|
||||
num_gpu_kvcache_blocks: int = -1
|
||||
num_cpu_kvcache_blocks: int = -1
|
||||
|
||||
def __post_init__(self):
|
||||
assert os.path.isdir(self.model)
|
||||
assert self.kvcache_block_size % 256 == 0
|
||||
|
||||
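A minimal usage sketch of the new offload fields (mirroring bench_offload.py above; passing offload_policy and num_transfer_streams through the LLM constructor is an assumption, since only enable_cpu_offload and cpu_memory_gb are exercised in this commit):

import os
from nanovllm import LLM, SamplingParams

llm = LLM(
    os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
    enforce_eager=True,
    max_model_len=128 * 1024,
    enable_cpu_offload=True,     # switch to the hybrid (GPU + CPU) KV cache manager
    cpu_memory_gb=32.0,          # CPU budget for offloaded KV blocks
    offload_policy="lru",        # eviction policy: "lru", "fifo", or a full class path
    num_transfer_streams=4,      # CUDA streams used for async block transfers
)
llm.generate(["Benchmark: "], SamplingParams())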
@@ -31,7 +31,7 @@ class LLMEngine:
|
||||
self.model_runner = ModelRunner(config, 0, self.events)
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
|
||||
config.eos = self.tokenizer.eos_token_id
|
||||
self.scheduler = Scheduler(config)
|
||||
self.scheduler = Scheduler(config, self.model_runner.kvcache_manager)
|
||||
atexit.register(self.exit)
|
||||
|
||||
def exit(self):
|
||||
|
||||
@@ -10,6 +10,7 @@ from nanovllm.models.qwen3 import Qwen3ForCausalLM
|
||||
from nanovllm.layers.sampler import Sampler
|
||||
from nanovllm.utils.context import set_context, get_context, reset_context
|
||||
from nanovllm.utils.loader import load_model
|
||||
from nanovllm.kvcache import create_kvcache_manager, KVCacheManager
|
||||
|
||||
|
||||
class ModelRunner:
|
||||
@@ -107,14 +108,45 @@ class ModelRunner:
|
||||
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
||||
head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
|
||||
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
|
||||
config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
|
||||
assert config.num_kvcache_blocks > 0
|
||||
self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
|
||||
|
||||
# Calculate GPU block count
|
||||
num_gpu_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
|
||||
assert num_gpu_blocks > 0
|
||||
|
||||
if config.enable_cpu_offload:
|
||||
# Calculate CPU blocks based on cpu_memory_gb
|
||||
cpu_bytes = int(config.cpu_memory_gb * 1024**3)
|
||||
num_cpu_blocks = cpu_bytes // block_bytes
|
||||
config.num_gpu_kvcache_blocks = num_gpu_blocks
|
||||
config.num_cpu_kvcache_blocks = num_cpu_blocks
|
||||
# For backward compatibility
|
||||
config.num_kvcache_blocks = num_gpu_blocks + num_cpu_blocks
|
||||
else:
|
||||
config.num_kvcache_blocks = num_gpu_blocks
|
||||
config.num_gpu_kvcache_blocks = num_gpu_blocks
|
||||
config.num_cpu_kvcache_blocks = 0
|
||||
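# Back-of-the-envelope sizing sketch (editor's illustration; the model dimensions below are
# assumptions for a Qwen3-4B-class model in bf16 and are not taken from this diff):
#   num_layers=36, num_kv_heads=8 (after TP split), head_dim=128, block_size=256, itemsize=2
#   block_bytes = 2 * 36 * 256 * 8 * 128 * 2 = 37,748,736 bytes, i.e. 36 MiB per block (K+V, all layers)
#   with cpu_memory_gb=32.0: num_cpu_blocks = int(32 * 1024**3) // block_bytes = 910 blocks,
#   which corresponds to roughly 233k offloadable tokens at block_size=256.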
|
||||
# Create KV cache manager using factory
|
||||
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
|
||||
|
||||
# Allocate cache through manager
|
||||
self.kvcache_manager.allocate_cache(
|
||||
num_layers=hf_config.num_hidden_layers,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
dtype=hf_config.torch_dtype,
|
||||
)
|
||||
|
||||
# Bind layer caches to attention modules and set layer_id
|
||||
layer_id = 0
|
||||
for module in self.model.modules():
|
||||
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
|
||||
module.k_cache = self.kv_cache[0, layer_id]
|
||||
module.v_cache = self.kv_cache[1, layer_id]
|
||||
k_cache, v_cache = self.kvcache_manager.get_layer_cache(layer_id)
|
||||
module.k_cache = k_cache
|
||||
module.v_cache = v_cache
|
||||
# Set layer_id for chunked prefill support
|
||||
if hasattr(module, "layer_id"):
|
||||
module.layer_id = layer_id
|
||||
layer_id += 1
|
||||
|
||||
def prepare_block_tables(self, seqs: list[Sequence]):
|
||||
@@ -123,7 +155,30 @@ class ModelRunner:
|
||||
block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
return block_tables
|
||||
|
||||
def prepare_prefill(self, seqs: list[Sequence]):
|
||||
def prepare_prefill(self, seqs: list[Sequence], chunk_info: list[tuple] = None):
|
||||
"""
|
||||
Prepare inputs for prefill.
|
||||
|
||||
Args:
|
||||
seqs: List of sequences to prefill
|
||||
chunk_info: Optional chunked prefill info from get_gpu_block_tables_partial().
|
||||
If provided, only process blocks in the chunk.
|
||||
Format: [(gpu_block_ids, start_block_idx, end_block_idx), ...]
|
||||
"""
|
||||
# Check if any sequence has blocks (not warmup)
|
||||
has_blocks = any(seq.block_table for seq in seqs)
|
||||
|
||||
gpu_block_tables = None
|
||||
if has_blocks and hasattr(self, 'kvcache_manager'):
|
||||
if chunk_info is None:
|
||||
# Standard prefill - try to get all blocks
|
||||
# This may fail if GPU doesn't have enough capacity
|
||||
self.kvcache_manager.prepare_for_attention(seqs, is_prefill=True)
|
||||
gpu_block_tables = self.kvcache_manager.get_gpu_block_tables(seqs)
|
||||
else:
|
||||
# Chunked prefill - use provided chunk info
|
||||
gpu_block_tables = [info[0] for info in chunk_info]
|
||||
|
||||
input_ids = []
|
||||
positions = []
|
||||
cu_seqlens_q = [0]
|
||||
@@ -132,27 +187,67 @@ class ModelRunner:
|
||||
max_seqlen_k = 0
|
||||
slot_mapping = []
|
||||
block_tables = None
|
||||
for seq in seqs:
|
||||
seqlen = len(seq)
|
||||
input_ids.extend(seq[seq.num_cached_tokens:])
|
||||
positions.extend(list(range(seq.num_cached_tokens, seqlen)))
|
||||
seqlen_q = seqlen - seq.num_cached_tokens
|
||||
seqlen_k = seqlen
|
||||
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
|
||||
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
|
||||
max_seqlen_q = max(seqlen_q, max_seqlen_q)
|
||||
max_seqlen_k = max(seqlen_k, max_seqlen_k)
|
||||
if not seq.block_table: # warmup
|
||||
continue
|
||||
for i in range(seq.num_cached_blocks, seq.num_blocks):
|
||||
start = seq.block_table[i] * self.block_size
|
||||
if i != seq.num_blocks - 1:
|
||||
end = start + self.block_size
|
||||
else:
|
||||
end = start + seq.last_block_num_tokens
|
||||
slot_mapping.extend(list(range(start, end)))
|
||||
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
|
||||
block_tables = self.prepare_block_tables(seqs)
|
||||
|
||||
for seq_idx, seq in enumerate(seqs):
|
||||
if chunk_info is not None:
|
||||
# Chunked prefill: only process blocks in the chunk
|
||||
gpu_blocks, start_block_idx, end_block_idx = chunk_info[seq_idx]
|
||||
if not gpu_blocks:
|
||||
continue
|
||||
|
||||
# Calculate token range for this chunk
|
||||
start_token = start_block_idx * self.block_size
|
||||
end_token = min(end_block_idx * self.block_size, len(seq))
|
||||
if end_block_idx == seq.num_blocks:
|
||||
# Last chunk includes partial last block
|
||||
end_token = len(seq)
|
||||
|
||||
# Input tokens for this chunk
|
||||
chunk_tokens = seq[start_token:end_token]
|
||||
input_ids.extend(chunk_tokens)
|
||||
positions.extend(list(range(start_token, end_token)))
|
||||
|
||||
seqlen_q = end_token - start_token
|
||||
seqlen_k = end_token # Context includes all tokens up to this point
|
||||
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
|
||||
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
|
||||
max_seqlen_q = max(seqlen_q, max_seqlen_q)
|
||||
max_seqlen_k = max(seqlen_k, max_seqlen_k)
|
||||
|
||||
# Slot mapping for blocks in this chunk
|
||||
for i, gpu_block_id in enumerate(gpu_blocks):
|
||||
block_idx = start_block_idx + i
|
||||
start = gpu_block_id * self.block_size
|
||||
if block_idx != seq.num_blocks - 1:
|
||||
end = start + self.block_size
|
||||
else:
|
||||
end = start + seq.last_block_num_tokens
|
||||
slot_mapping.extend(list(range(start, end)))
|
||||
else:
|
||||
# Standard prefill
|
||||
seqlen = len(seq)
|
||||
input_ids.extend(seq[seq.num_cached_tokens:])
|
||||
positions.extend(list(range(seq.num_cached_tokens, seqlen)))
|
||||
seqlen_q = seqlen - seq.num_cached_tokens
|
||||
seqlen_k = seqlen
|
||||
cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
|
||||
cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
|
||||
max_seqlen_q = max(seqlen_q, max_seqlen_q)
|
||||
max_seqlen_k = max(seqlen_k, max_seqlen_k)
|
||||
if not seq.block_table: # warmup
|
||||
continue
|
||||
# Use GPU physical block IDs for slot mapping
|
||||
gpu_blocks = gpu_block_tables[seq_idx]
|
||||
for i in range(seq.num_cached_blocks, seq.num_blocks):
|
||||
start = gpu_blocks[i] * self.block_size
|
||||
if i != seq.num_blocks - 1:
|
||||
end = start + self.block_size
|
||||
else:
|
||||
end = start + seq.last_block_num_tokens
|
||||
slot_mapping.extend(list(range(start, end)))
|
||||
|
||||
if cu_seqlens_k[-1] > cu_seqlens_q[-1] and gpu_block_tables: # prefix cache
|
||||
block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
@@ -162,23 +257,40 @@ class ModelRunner:
|
||||
return input_ids, positions
|
||||
|
||||
def prepare_decode(self, seqs: list[Sequence]):
|
||||
# Prepare KV cache (updates gather_indices for hybrid manager)
|
||||
if hasattr(self, 'kvcache_manager'):
|
||||
self.kvcache_manager.prepare_for_attention(seqs, is_prefill=False)
|
||||
# Get GPU physical block tables
|
||||
gpu_block_tables = self.kvcache_manager.get_gpu_block_tables(seqs)
|
||||
else:
|
||||
gpu_block_tables = [list(seq.block_table) for seq in seqs]
|
||||
|
||||
input_ids = []
|
||||
positions = []
|
||||
slot_mapping = []
|
||||
context_lens = []
|
||||
for seq in seqs:
|
||||
for seq_idx, seq in enumerate(seqs):
|
||||
input_ids.append(seq.last_token)
|
||||
positions.append(len(seq) - 1)
|
||||
context_lens.append(len(seq))
|
||||
slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1)
|
||||
# Use GPU physical block ID for slot mapping
|
||||
gpu_blocks = gpu_block_tables[seq_idx]
|
||||
slot_mapping.append(gpu_blocks[-1] * self.block_size + seq.last_block_num_tokens - 1)
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
block_tables = self.prepare_block_tables(seqs)
|
||||
# Use GPU physical block tables for attention
|
||||
block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
|
||||
set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
|
||||
return input_ids, positions
|
||||
|
||||
def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]):
|
||||
"""Prepare block tables tensor from GPU physical block IDs."""
|
||||
max_len = max(len(bt) for bt in gpu_block_tables)
|
||||
padded = [bt + [-1] * (max_len - len(bt)) for bt in gpu_block_tables]
|
||||
return torch.tensor(padded, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
|
||||
def prepare_sample(self, seqs: list[Sequence]):
|
||||
temperatures = []
|
||||
for seq in seqs:
|
||||
@@ -206,6 +318,26 @@ class ModelRunner:
|
||||
return self.model.compute_logits(graph_vars["outputs"][:bs])
|
||||
|
||||
def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
|
||||
# Check if chunked prefill is needed
|
||||
if is_prefill and hasattr(self, 'kvcache_manager'):
|
||||
needs_chunked = any(
|
||||
hasattr(self.kvcache_manager, 'needs_chunked_prefill') and
|
||||
self.kvcache_manager.needs_chunked_prefill(seq)
|
||||
for seq in seqs if seq.block_table
|
||||
)
|
||||
if needs_chunked:
|
||||
return self.run_chunked_prefill(seqs)
|
||||
|
||||
# Check if chunked decode is needed
|
||||
if not is_prefill and hasattr(self, 'kvcache_manager'):
|
||||
needs_chunked = any(
|
||||
hasattr(self.kvcache_manager, 'needs_chunked_decode') and
|
||||
self.kvcache_manager.needs_chunked_decode(seq)
|
||||
for seq in seqs if seq.block_table
|
||||
)
|
||||
if needs_chunked:
|
||||
return self.run_chunked_decode(seqs)
|
||||
|
||||
input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
logits = self.run_model(input_ids, positions, is_prefill)
|
||||
@@ -213,6 +345,204 @@ class ModelRunner:
|
||||
reset_context()
|
||||
return token_ids
|
||||
|
||||
def run_chunked_prefill(self, seqs: list[Sequence]) -> list[int]:
|
||||
"""
|
||||
Run prefill in chunks when sequences exceed GPU capacity.
|
||||
|
||||
For each chunk:
|
||||
1. Process tokens through model forward pass
|
||||
2. At each attention layer:
|
||||
- Load previous KV from CPU (handled by attention layer)
|
||||
- Compute attention with online softmax merging
|
||||
- Store current KV to GPU cache
|
||||
3. After chunk completes, offload KV to CPU
|
||||
4. Load next chunk's blocks to GPU
|
||||
"""
|
||||
import sys
|
||||
|
||||
# Currently only supporting single sequence for chunked prefill
|
||||
assert len(seqs) == 1, "Chunked prefill only supports single sequence"
|
||||
seq = seqs[0]
|
||||
|
||||
total_blocks = seq.num_blocks
|
||||
print(f"[Chunked Prefill] Starting: {total_blocks} total blocks, "
|
||||
f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)
|
||||
|
||||
chunk_num = 0
|
||||
logits = None
|
||||
|
||||
while True:
|
||||
# Get chunk info (which blocks are on GPU and not yet prefilled)
|
||||
chunk_info = self.kvcache_manager.get_gpu_block_tables_partial(seqs)
|
||||
gpu_blocks, start_block_idx, end_block_idx = chunk_info[0]
|
||||
|
||||
if not gpu_blocks:
|
||||
# No more blocks to process
|
||||
break
|
||||
|
||||
chunk_num += 1
|
||||
chunk_tokens = (end_block_idx - start_block_idx) * self.block_size
|
||||
if end_block_idx == seq.num_blocks:
|
||||
# Last block may be partial
|
||||
chunk_tokens = len(seq) - start_block_idx * self.block_size
|
||||
|
||||
print(f"[Chunked Prefill] Chunk {chunk_num}: blocks {start_block_idx}-{end_block_idx-1}, "
|
||||
f"~{chunk_tokens} tokens", file=sys.stderr)
|
||||
|
||||
# Prepare inputs for this chunk
|
||||
input_ids, positions = self._prepare_chunked_prefill(seq, gpu_blocks, start_block_idx, end_block_idx)
|
||||
|
||||
if input_ids.numel() == 0:
|
||||
print(f"[Chunked Prefill] No input tokens, breaking", file=sys.stderr)
|
||||
break
|
||||
|
||||
print(f"[Chunked Prefill] Running model with {input_ids.numel()} tokens...", file=sys.stderr)
|
||||
|
||||
# Run model forward pass
|
||||
logits = self.run_model(input_ids, positions, is_prefill=True)
|
||||
reset_context()
|
||||
|
||||
print(f"[Chunked Prefill] Model forward complete", file=sys.stderr)
|
||||
|
||||
# Check if this is the last chunk
|
||||
# Mark current chunk as prefilled and offload to CPU
|
||||
self.kvcache_manager.complete_prefill_chunk(seq)
|
||||
|
||||
# Check if more chunks needed
|
||||
if not self.kvcache_manager.needs_chunked_prefill(seq):
|
||||
print(f"[Chunked Prefill] All chunks done, sampling", file=sys.stderr)
|
||||
break
|
||||
|
||||
print(f"[Chunked Prefill] Chunk transfer complete, loading next...", file=sys.stderr)
|
||||
|
||||
# Sample from the last chunk's logits
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
if logits is not None:
|
||||
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
|
||||
else:
|
||||
token_ids = [0] if self.rank == 0 else None
|
||||
|
||||
return token_ids
|
||||
|
||||
def run_chunked_decode(self, seqs: list[Sequence]) -> list[int]:
|
||||
"""
|
||||
Run decode with chunked attention when sequence exceeds GPU capacity.
|
||||
|
||||
For decode, we need attention over ALL previous tokens. With CPU offload,
|
||||
we load KV chunks and compute attention incrementally.
|
||||
"""
|
||||
import sys
|
||||
|
||||
# Currently only supporting single sequence for chunked decode
|
||||
assert len(seqs) == 1, "Chunked decode only supports single sequence"
|
||||
seq = seqs[0]
|
||||
|
||||
total_blocks = len(seq.block_table)
|
||||
print(f"[Chunked Decode] Sequence has {total_blocks} blocks, "
|
||||
f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)
|
||||
|
||||
# Prepare inputs
|
||||
input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
|
||||
# Compute slot mapping for the new token
|
||||
# Get the last block's GPU slot if it's on GPU, otherwise we need to handle it
|
||||
last_logical_id = seq.block_table[-1]
|
||||
last_block = self.kvcache_manager.logical_blocks[last_logical_id]
|
||||
|
||||
if last_block.location.name == "GPU":
|
||||
slot = last_block.gpu_slot * self.block_size + seq.last_block_num_tokens - 1
|
||||
else:
|
||||
# Last block is on CPU - we need to bring it to GPU for writing the new token
|
||||
# This is a special case - allocate a temporary GPU slot
|
||||
# For simplicity, use a fixed slot (this might conflict, but for decode
|
||||
# we only write 1 token so it should be ok)
|
||||
print(f"[Chunked Decode] Warning: last block on CPU, using temp slot", file=sys.stderr)
|
||||
slot = 0 # Use first slot temporarily
|
||||
|
||||
slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
|
||||
# Set up context for chunked decode
|
||||
set_context(
|
||||
is_prefill=False, # Decode mode
|
||||
slot_mapping=slot_mapping,
|
||||
context_lens=context_len,
|
||||
is_chunked_prefill=True, # Use chunked attention
|
||||
offload_engine=self.kvcache_manager,
|
||||
chunked_seq=seq,
|
||||
)
|
||||
|
||||
# Run model forward pass
|
||||
logits = self.run_model(input_ids, positions, is_prefill=False)
|
||||
reset_context()
|
||||
|
||||
# Sample
|
||||
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
|
||||
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
|
||||
|
||||
return token_ids
|
||||
|
||||
def _prepare_chunked_prefill(
|
||||
self,
|
||||
seq: Sequence,
|
||||
gpu_blocks: list[int],
|
||||
start_block_idx: int,
|
||||
end_block_idx: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Prepare inputs for a single chunk in chunked prefill.
|
||||
|
||||
Sets up context with is_chunked_prefill=True so attention layers
|
||||
know to load previous KV from CPU.
|
||||
"""
|
||||
# Calculate token range for this chunk
|
||||
start_token = start_block_idx * self.block_size
|
||||
end_token = min(end_block_idx * self.block_size, len(seq))
|
||||
|
||||
# Input tokens for this chunk
|
||||
input_ids = seq[start_token:end_token]
|
||||
positions = list(range(start_token, end_token))
|
||||
|
||||
# Slot mapping for storing KV cache
|
||||
slot_mapping = []
|
||||
for i, gpu_block_id in enumerate(gpu_blocks):
|
||||
block_idx = start_block_idx + i
|
||||
start = gpu_block_id * self.block_size
|
||||
if block_idx != seq.num_blocks - 1:
|
||||
end = start + self.block_size
|
||||
else:
|
||||
end = start + seq.last_block_num_tokens
|
||||
slot_mapping.extend(list(range(start, end)))
|
||||
|
||||
# Trim slot_mapping to match actual token count
|
||||
actual_tokens = end_token - start_token
|
||||
slot_mapping = slot_mapping[:actual_tokens]
|
||||
|
||||
# Convert to tensors
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||
slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
|
||||
# Set up context for chunked prefill
|
||||
seqlen = actual_tokens
|
||||
cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
|
||||
|
||||
set_context(
|
||||
is_prefill=True,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=seqlen,
|
||||
max_seqlen_k=seqlen,
|
||||
slot_mapping=slot_mapping,
|
||||
is_chunked_prefill=True,
|
||||
offload_engine=self.kvcache_manager, # Pass manager for loading previous KV
|
||||
chunked_seq=seq, # Pass sequence for loading previous KV
|
||||
)
|
||||
|
||||
return input_ids, positions
|
||||
|
||||
@torch.inference_mode()
|
||||
def capture_cudagraph(self):
|
||||
config = self.config
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
from collections import deque
|
||||
from time import perf_counter_ns
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from nanovllm.config import Config
|
||||
from nanovllm.engine.sequence import Sequence, SequenceStatus
|
||||
from nanovllm.engine.block_manager import BlockManager
|
||||
from nanovllm.utils.observer import Observer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanovllm.kvcache import KVCacheManager
|
||||
|
||||
|
||||
class Scheduler:
|
||||
|
||||
def __init__(self, config: Config):
|
||||
def __init__(self, config: Config, kvcache_manager: "KVCacheManager"):
|
||||
self.max_num_seqs = config.max_num_seqs
|
||||
self.max_num_batched_tokens = config.max_num_batched_tokens
|
||||
self.eos = config.eos
|
||||
self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size)
|
||||
self.kvcache_manager = kvcache_manager
|
||||
self.waiting: deque[Sequence] = deque()
|
||||
self.running: deque[Sequence] = deque()
|
||||
|
||||
@@ -32,10 +35,10 @@ class Scheduler:
|
||||
if Observer.ttft_start == 0:
|
||||
Observer.ttft_start = perf_counter_ns()
|
||||
seq = self.waiting[0]
|
||||
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq):
|
||||
if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.kvcache_manager.can_allocate(seq):
|
||||
break
|
||||
num_seqs += 1
|
||||
self.block_manager.allocate(seq)
|
||||
self.kvcache_manager.allocate(seq)
|
||||
num_batched_tokens += len(seq) - seq.num_cached_tokens
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
self.waiting.popleft()
|
||||
@@ -47,7 +50,7 @@ class Scheduler:
|
||||
# decode
|
||||
while self.running and num_seqs < self.max_num_seqs:
|
||||
seq = self.running.popleft()
|
||||
while not self.block_manager.can_append(seq):
|
||||
while not self.kvcache_manager.can_append(seq):
|
||||
if self.running:
|
||||
self.preempt(self.running.pop())
|
||||
else:
|
||||
@@ -55,7 +58,7 @@ class Scheduler:
|
||||
break
|
||||
else:
|
||||
num_seqs += 1
|
||||
self.block_manager.may_append(seq)
|
||||
self.kvcache_manager.may_append(seq)
|
||||
scheduled_seqs.append(seq)
|
||||
assert scheduled_seqs
|
||||
self.running.extendleft(reversed(scheduled_seqs))
|
||||
@@ -63,7 +66,7 @@ class Scheduler:
|
||||
|
||||
def preempt(self, seq: Sequence):
|
||||
seq.status = SequenceStatus.WAITING
|
||||
self.block_manager.deallocate(seq)
|
||||
self.kvcache_manager.deallocate(seq)
|
||||
self.waiting.appendleft(seq)
|
||||
|
||||
def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]:
|
||||
@@ -71,5 +74,5 @@ class Scheduler:
|
||||
seq.append_token(token_id)
|
||||
if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens:
|
||||
seq.status = SequenceStatus.FINISHED
|
||||
self.block_manager.deallocate(seq)
|
||||
self.kvcache_manager.deallocate(seq)
|
||||
self.running.remove(seq)
|
||||
|
||||
nanovllm/kvcache/__init__.py | 74 (new file)
@@ -0,0 +1,74 @@
|
||||
"""
|
||||
KV Cache management module.
|
||||
|
||||
This module provides pluggable KV cache management strategies:
|
||||
- GPUOnlyManager: Pure GPU (default, current nano-vllm behavior)
|
||||
- HybridKVCacheManager: CPU offload with CUDA Graph support
|
||||
|
||||
Usage:
|
||||
from nanovllm.kvcache import create_kvcache_manager
|
||||
|
||||
manager = create_kvcache_manager(config)
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from nanovllm.kvcache.base_manager import KVCacheManager
|
||||
from nanovllm.kvcache.gpu_manager import GPUOnlyManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from nanovllm.config import Config
|
||||
|
||||
|
||||
def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
||||
"""
|
||||
Factory function to create the appropriate KV cache manager.
|
||||
|
||||
Decision logic:
|
||||
1. If enable_cpu_offload=False: use GPUOnlyManager
|
||||
2. If enable_cpu_offload=True but all blocks fit in GPU: use GPUOnlyManager
|
||||
3. If enable_cpu_offload=True and need CPU blocks: use HybridKVCacheManager
|
||||
|
||||
Args:
|
||||
config: Model configuration with offload settings
|
||||
|
||||
Returns:
|
||||
KVCacheManager instance
|
||||
"""
|
||||
if not getattr(config, 'enable_cpu_offload', False):
|
||||
# Default: pure GPU mode
|
||||
return GPUOnlyManager(
|
||||
num_blocks=config.num_kvcache_blocks,
|
||||
block_size=config.kvcache_block_size,
|
||||
)
|
||||
|
||||
# CPU offload is enabled
|
||||
num_gpu_blocks = config.num_gpu_kvcache_blocks
|
||||
num_cpu_blocks = config.num_cpu_kvcache_blocks
|
||||
|
||||
if num_cpu_blocks <= 0:
|
||||
# All blocks fit in GPU, use pure GPU mode
|
||||
return GPUOnlyManager(
|
||||
num_blocks=num_gpu_blocks,
|
||||
block_size=config.kvcache_block_size,
|
||||
)
|
||||
|
||||
# Need CPU offload: use hybrid manager
|
||||
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
|
||||
from nanovllm.kvcache.policies import get_policy
|
||||
|
||||
policy = get_policy(getattr(config, 'offload_policy', 'lru'))
|
||||
|
||||
return HybridKVCacheManager(
|
||||
num_gpu_slots=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
block_size=config.kvcache_block_size,
|
||||
policy=policy,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"KVCacheManager",
|
||||
"GPUOnlyManager",
|
||||
"create_kvcache_manager",
|
||||
]
|
||||
nanovllm/kvcache/base_manager.py | 260 (new file)
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
Abstract base class for KV cache managers.
|
||||
|
||||
This interface allows pluggable implementations:
|
||||
- GPUOnlyManager: Pure GPU (current nano-vllm behavior)
|
||||
- HybridKVCacheManager: CPU offload with CUDA Graph support
|
||||
- Future: Disk offload, distributed cache, etc.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Optional
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
|
||||
|
||||
class KVCacheManager(ABC):
|
||||
"""
|
||||
Abstract base class for KV cache management strategies.
|
||||
|
||||
A KVCacheManager handles:
|
||||
1. Physical memory allocation (GPU and optionally CPU)
|
||||
2. Logical block management (allocation, deallocation, prefix caching)
|
||||
3. Data transfer between devices (for hybrid managers)
|
||||
4. Integration with CUDA graphs
|
||||
|
||||
Key design principles:
|
||||
- Sequences reference logical block IDs
|
||||
- Physical block IDs (GPU slots) may differ from logical IDs
|
||||
- CUDA Graph compatibility requires fixed tensor addresses
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def block_size(self) -> int:
|
||||
"""Number of tokens per block."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def num_free_blocks(self) -> int:
|
||||
"""Number of free logical blocks available for allocation."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def allocate_cache(
|
||||
self,
|
||||
num_layers: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
"""
|
||||
Allocate KV cache storage.
|
||||
|
||||
Called once during initialization to allocate GPU (and optionally CPU)
|
||||
memory for the KV cache.
|
||||
|
||||
Args:
|
||||
num_layers: Number of transformer layers
|
||||
num_kv_heads: Number of key-value heads per layer
|
||||
head_dim: Dimension per head
|
||||
dtype: Data type for cache (e.g., torch.float16)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Get K and V cache tensors for a specific layer.
|
||||
|
||||
The returned tensors must be on GPU and have fixed addresses
|
||||
for CUDA Graph compatibility.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index
|
||||
|
||||
Returns:
|
||||
(k_cache, v_cache) tensors
|
||||
Shape depends on implementation, typically:
|
||||
[num_blocks, block_size, kv_heads, head_dim]
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def can_allocate(self, seq: Sequence) -> bool:
|
||||
"""
|
||||
Check if blocks can be allocated for a new sequence.
|
||||
|
||||
Called before allocate() to ensure sufficient resources.
|
||||
|
||||
Args:
|
||||
seq: Sequence to check
|
||||
|
||||
Returns:
|
||||
True if allocation is possible
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def allocate(self, seq: Sequence) -> None:
|
||||
"""
|
||||
Allocate blocks for a sequence during prefill.
|
||||
|
||||
This method:
|
||||
1. Checks prefix cache for matching blocks
|
||||
2. Allocates new blocks as needed
|
||||
3. Updates seq.block_table with logical block IDs
|
||||
4. Updates seq.num_cached_tokens for prefix cache hits
|
||||
|
||||
Args:
|
||||
seq: Sequence to allocate blocks for
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def deallocate(self, seq: Sequence) -> None:
|
||||
"""
|
||||
Release blocks for a finished sequence.
|
||||
|
||||
This method:
|
||||
1. Decrements reference counts
|
||||
2. Frees blocks with zero references
|
||||
3. Clears seq.block_table
|
||||
|
||||
Args:
|
||||
seq: Sequence whose blocks to release
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def can_append(self, seq: Sequence) -> bool:
|
||||
"""
|
||||
Check if a new block can be allocated for decode.
|
||||
|
||||
Called before may_append() to check if resources are available.
|
||||
|
||||
Args:
|
||||
seq: Sequence to check
|
||||
|
||||
Returns:
|
||||
True if append is possible (or no new block needed)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def may_append(self, seq: Sequence) -> None:
|
||||
"""
|
||||
Potentially allocate a new block during decode.
|
||||
|
||||
Called after each decode step. If the current block is full,
|
||||
allocates a new block and updates seq.block_table.
|
||||
|
||||
Args:
|
||||
seq: Sequence that may need a new block
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def prepare_for_attention(
|
||||
self,
|
||||
seqs: List[Sequence],
|
||||
is_prefill: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Prepare KV cache for attention computation.
|
||||
|
||||
For GPU-only managers: typically a no-op.
|
||||
For hybrid managers: ensures all needed blocks are on GPU,
|
||||
may trigger prefetching from CPU.
|
||||
|
||||
Called before attention computation. For decode with CUDA graphs,
|
||||
this should update gather_indices but not perform actual transfers
|
||||
(transfers happen inside the graph).
|
||||
|
||||
Args:
|
||||
seqs: Sequences that will be processed
|
||||
is_prefill: True for prefill phase, False for decode
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_gpu_block_tables(
|
||||
self,
|
||||
seqs: List[Sequence],
|
||||
) -> List[List[int]]:
|
||||
"""
|
||||
Get GPU physical block tables for sequences.
|
||||
|
||||
For GPU-only managers: returns seq.block_table directly.
|
||||
For hybrid managers: returns GPU slot IDs (may differ from logical IDs).
|
||||
|
||||
The returned block tables are used to compute slot_mapping
|
||||
in ModelRunner.prepare_prefill/decode.
|
||||
|
||||
Args:
|
||||
seqs: Sequences to get block tables for
|
||||
|
||||
Returns:
|
||||
List of GPU block tables, one per sequence
|
||||
"""
|
||||
pass
|
||||
|
||||
def post_attention_cleanup(
|
||||
self,
|
||||
seqs: List[Sequence],
|
||||
is_prefill: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Cleanup after attention computation.
|
||||
|
||||
Optional hook for managers to perform post-attention tasks:
|
||||
- Offloading cold blocks to CPU
|
||||
- Updating access statistics
|
||||
- etc.
|
||||
|
||||
Default implementation does nothing.
|
||||
|
||||
Args:
|
||||
seqs: Sequences that were processed
|
||||
is_prefill: True for prefill phase, False for decode
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_num_blocks_needed(self, num_tokens: int) -> int:
|
||||
"""
|
||||
Calculate number of blocks needed for given token count.
|
||||
|
||||
Args:
|
||||
num_tokens: Number of tokens
|
||||
|
||||
Returns:
|
||||
Number of blocks needed
|
||||
"""
|
||||
return (num_tokens + self.block_size - 1) // self.block_size
|
||||
|
||||
@staticmethod
|
||||
def compute_hash(token_ids: list, prefix: int = -1) -> int:
|
||||
"""
|
||||
Compute hash for prefix caching.
|
||||
|
||||
Uses xxhash for fast hashing. The hash includes the prefix hash
|
||||
to create a chain of hashes for multi-block sequences.
|
||||
|
||||
Args:
|
||||
token_ids: Token IDs in the block
|
||||
prefix: Hash of previous block, or -1 for first block
|
||||
|
||||
Returns:
|
||||
Hash value
|
||||
"""
|
||||
import xxhash
|
||||
import numpy as np
|
||||
|
||||
h = xxhash.xxh64()
|
||||
if prefix != -1:
|
||||
h.update(prefix.to_bytes(8, "little"))
|
||||
h.update(np.array(token_ids).tobytes())
|
||||
return h.intdigest()
|
||||
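A small sketch of the hash chaining used for prefix caching (illustrative token values; compute_hash is the static method defined above, with block_size shortened to 4 purely for the example):

block0 = [1, 2, 3, 4]
block1 = [5, 6, 7, 8]
h0 = KVCacheManager.compute_hash(block0)             # first block: prefix defaults to -1
h1 = KVCacheManager.compute_hash(block1, prefix=h0)  # chained: depends on h0
# Any sequence starting with the same eight tokens reproduces (h0, h1), so its first two
# blocks can be served from the prefix cache instead of being recomputed.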
nanovllm/kvcache/chunked_attention.py | 555 (new file)
@@ -0,0 +1,555 @@
|
||||
"""
|
||||
Chunked attention implementation for CPU KV cache offloading.
|
||||
|
||||
This module implements flash attention with LSE (log-sum-exp) output,
|
||||
enabling proper online softmax merging for chunked prefill.
|
||||
|
||||
Key functions:
|
||||
- flash_attn_with_lse: Flash attention that returns output and LSE
|
||||
- merge_attention_outputs: Merge outputs from multiple KV chunks
|
||||
- chunked_prefill_attention: High-level interface for chunked attention
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from typing import Tuple, List, Optional
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
{
|
||||
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
|
||||
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
|
||||
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
|
||||
}
|
||||
)
|
||||
@triton.jit
|
||||
def _fwd_kernel_with_lse(
|
||||
Q,
|
||||
K,
|
||||
V,
|
||||
Out,
|
||||
Lse,
|
||||
TMP,
|
||||
softmax_scale,
|
||||
stride_qb,
|
||||
stride_qh,
|
||||
stride_qm,
|
||||
stride_kb,
|
||||
stride_kh,
|
||||
stride_kn,
|
||||
stride_vb,
|
||||
stride_vh,
|
||||
stride_vn,
|
||||
stride_ob,
|
||||
stride_oh,
|
||||
stride_om,
|
||||
nheads,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
seqlen_q_rounded,
|
||||
headdim,
|
||||
CACHE_KEY_SEQLEN_Q,
|
||||
CACHE_KEY_SEQLEN_K,
|
||||
IS_CAUSAL: tl.constexpr,
|
||||
BLOCK_HEADDIM: tl.constexpr,
|
||||
EVEN_M: tl.constexpr,
|
||||
EVEN_N: tl.constexpr,
|
||||
EVEN_HEADDIM: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
):
|
||||
"""Flash attention forward kernel with LSE output for online softmax."""
|
||||
start_m = tl.program_id(0)
|
||||
off_hb = tl.program_id(1)
|
||||
off_b = off_hb // nheads
|
||||
off_h = off_hb % nheads
|
||||
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
offs_n = tl.arange(0, BLOCK_N)
|
||||
offs_d = tl.arange(0, BLOCK_HEADDIM)
|
||||
|
||||
q_ptrs = (
|
||||
Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
|
||||
)
|
||||
k_ptrs = (
|
||||
K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
|
||||
)
|
||||
v_ptrs = (
|
||||
V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
|
||||
)
|
||||
|
||||
t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m
|
||||
lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
|
||||
|
||||
# Load Q
|
||||
if EVEN_M & EVEN_N:
|
||||
if EVEN_HEADDIM:
|
||||
q = tl.load(q_ptrs)
|
||||
else:
|
||||
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
|
||||
else:
|
||||
if EVEN_HEADDIM:
|
||||
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
|
||||
else:
|
||||
q = tl.load(
|
||||
q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
|
||||
)
|
||||
|
||||
# Loop over k, v and update accumulator
|
||||
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
|
||||
for start_n in range(0, end_n, BLOCK_N):
|
||||
start_n = tl.multiple_of(start_n, BLOCK_N)
|
||||
|
||||
# Load K
|
||||
if EVEN_N & EVEN_M:
|
||||
if EVEN_HEADDIM:
|
||||
k = tl.load(k_ptrs + start_n * stride_kn)
|
||||
else:
|
||||
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
|
||||
else:
|
||||
if EVEN_HEADDIM:
|
||||
k = tl.load(
|
||||
k_ptrs + start_n * stride_kn,
|
||||
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
k = tl.load(
|
||||
k_ptrs + start_n * stride_kn,
|
||||
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
# Compute QK
|
||||
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
qk += tl.dot(q, tl.trans(k))
|
||||
|
||||
# Masking
|
||||
if not EVEN_N:
|
||||
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
|
||||
if IS_CAUSAL:
|
||||
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
|
||||
|
||||
# Online softmax
|
||||
m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
|
||||
p = tl.exp(qk * softmax_scale - m_ij[:, None])
|
||||
l_ij = tl.sum(p, 1)
|
||||
|
||||
# Scale acc_o
|
||||
acc_o_scale = tl.exp(m_i - m_ij)
|
||||
tl.store(t_ptrs, acc_o_scale)
|
||||
acc_o_scale = tl.load(t_ptrs)
|
||||
acc_o = acc_o * acc_o_scale[:, None]
|
||||
|
||||
# Load V and update output
|
||||
if EVEN_N & EVEN_M:
|
||||
if EVEN_HEADDIM:
|
||||
v = tl.load(v_ptrs + start_n * stride_vn)
|
||||
else:
|
||||
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
|
||||
else:
|
||||
if EVEN_HEADDIM:
|
||||
v = tl.load(
|
||||
v_ptrs + start_n * stride_vn,
|
||||
mask=(start_n + offs_n)[:, None] < seqlen_k,
|
||||
other=0.0,
|
||||
)
|
||||
else:
|
||||
v = tl.load(
|
||||
v_ptrs + start_n * stride_vn,
|
||||
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
|
||||
other=0.0,
|
||||
)
|
||||
|
||||
p = p.to(v.dtype)
|
||||
acc_o += tl.dot(p, v)
|
||||
|
||||
# Update statistics
|
||||
m_i = m_ij
|
||||
l_i_new = tl.exp(lse_i - m_ij) + l_ij
|
||||
lse_i = m_ij + tl.log(l_i_new)
|
||||
|
||||
# Final scaling
|
||||
o_scale = tl.exp(m_i - lse_i)
|
||||
tl.store(t_ptrs, o_scale)
|
||||
o_scale = tl.load(t_ptrs)
|
||||
acc_o = acc_o * o_scale[:, None]
|
||||
|
||||
# Store LSE
|
||||
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
|
||||
tl.store(lse_ptrs, lse_i)
|
||||
|
||||
# Store output
|
||||
out_ptrs = (
|
||||
Out
|
||||
+ off_b * stride_ob
|
||||
+ off_h * stride_oh
|
||||
+ (offs_m[:, None] * stride_om + offs_d[None, :])
|
||||
)
|
||||
if EVEN_M:
|
||||
if EVEN_HEADDIM:
|
||||
tl.store(out_ptrs, acc_o)
|
||||
else:
|
||||
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
|
||||
else:
|
||||
if EVEN_HEADDIM:
|
||||
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
|
||||
else:
|
||||
tl.store(
|
||||
out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
|
||||
)
|
||||
|
||||
|
||||
def flash_attn_with_lse(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Flash attention forward pass that returns both output and LSE.
|
||||
|
||||
Supports GQA (grouped query attention) where num_kv_heads < num_q_heads.
|
||||
|
||||
Args:
|
||||
q: Query tensor [batch, seqlen_q, nheads_q, headdim]
|
||||
k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
|
||||
v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
|
||||
softmax_scale: Scaling factor (default: 1/sqrt(headdim))
|
||||
causal: Whether to apply causal masking
|
||||
|
||||
Returns:
|
||||
out: Output tensor [batch, seqlen_q, nheads_q, headdim]
|
||||
lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
|
||||
"""
|
||||
# Ensure contiguous
|
||||
if not q.is_contiguous():
|
||||
q = q.contiguous()
|
||||
if not k.is_contiguous():
|
||||
k = k.contiguous()
|
||||
if not v.is_contiguous():
|
||||
v = v.contiguous()
|
||||
|
||||
batch, seqlen_q, nheads_q, headdim = q.shape
|
||||
_, seqlen_k, nheads_kv, _ = k.shape
|
||||
|
||||
assert k.shape == (batch, seqlen_k, nheads_kv, headdim)
|
||||
assert v.shape == (batch, seqlen_k, nheads_kv, headdim)
|
||||
assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
|
||||
assert q.dtype == k.dtype == v.dtype
|
||||
|
||||
# Handle GQA by repeating K/V heads
|
||||
if nheads_kv != nheads_q:
|
||||
assert nheads_q % nheads_kv == 0, f"nheads_q ({nheads_q}) must be divisible by nheads_kv ({nheads_kv})"
|
||||
repeat_factor = nheads_q // nheads_kv
|
||||
# [batch, seqlen_k, nheads_kv, headdim] -> [batch, seqlen_k, nheads_q, headdim]
|
||||
k = k.repeat_interleave(repeat_factor, dim=2)
|
||||
v = v.repeat_interleave(repeat_factor, dim=2)
|
||||
nheads = nheads_q
|
||||
else:
|
||||
nheads = nheads_q
|
||||
|
||||
if softmax_scale is None:
|
||||
softmax_scale = 1.0 / math.sqrt(headdim)
|
||||
|
||||
seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128
|
||||
lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
|
||||
tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
|
||||
out = torch.empty_like(q)
|
||||
|
||||
BLOCK_HEADDIM = max(triton.next_power_of_2(headdim), 16)
|
||||
BLOCK = 128
|
||||
num_warps = 4 if headdim <= 64 else 8
|
||||
grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
|
||||
|
||||
_fwd_kernel_with_lse[grid](
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out,
|
||||
lse,
|
||||
tmp,
|
||||
softmax_scale,
|
||||
q.stride(0),
|
||||
q.stride(2),
|
||||
q.stride(1),
|
||||
k.stride(0),
|
||||
k.stride(2),
|
||||
k.stride(1),
|
||||
v.stride(0),
|
||||
v.stride(2),
|
||||
v.stride(1),
|
||||
out.stride(0),
|
||||
out.stride(2),
|
||||
out.stride(1),
|
||||
nheads,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
seqlen_q_rounded,
|
||||
headdim,
|
||||
seqlen_q // 32,
|
||||
seqlen_k // 32,
|
||||
causal,
|
||||
BLOCK_HEADDIM,
|
||||
BLOCK_M=BLOCK,
|
||||
BLOCK_N=BLOCK,
|
||||
num_warps=num_warps,
|
||||
num_stages=1,
|
||||
)
|
||||
|
||||
# Trim LSE to actual seqlen_q
|
||||
lse = lse[:, :, :seqlen_q]
|
||||
|
||||
# Ensure output has same dtype as input
|
||||
out = out.to(q.dtype)
|
||||
|
||||
return out, lse
|
||||
|
||||
|
||||
def merge_attention_outputs(
|
||||
o1: torch.Tensor,
|
||||
lse1: torch.Tensor,
|
||||
o2: torch.Tensor,
|
||||
lse2: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Merge two attention outputs using online softmax.
|
||||
|
||||
This implements the online softmax merging formula:
|
||||
- m_new = max(lse1, lse2)
|
||||
- o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2) / (exp(lse1 - m_new) + exp(lse2 - m_new))
|
||||
- lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new))
|
||||
|
||||
Args:
|
||||
o1: First output [batch, seqlen_q, nheads, headdim]
|
||||
lse1: First LSE [batch, nheads, seqlen_q]
|
||||
o2: Second output [batch, seqlen_q, nheads, headdim]
|
||||
lse2: Second LSE [batch, nheads, seqlen_q]
|
||||
|
||||
Returns:
|
||||
o_merged: Merged output [batch, seqlen_q, nheads, headdim]
|
||||
lse_merged: Merged LSE [batch, nheads, seqlen_q]
|
||||
"""
|
||||
# lse shape: [batch, nheads, seqlen_q]
|
||||
# o shape: [batch, seqlen_q, nheads, headdim]
|
||||
|
||||
# Compute max for numerical stability
|
||||
max_lse = torch.maximum(lse1, lse2)
|
||||
|
||||
# Compute scaling factors
|
||||
# exp1, exp2 shape: [batch, nheads, seqlen_q]
|
||||
exp1 = torch.exp(lse1 - max_lse)
|
||||
exp2 = torch.exp(lse2 - max_lse)
|
||||
|
||||
# Reshape for broadcasting with output
|
||||
# [batch, nheads, seqlen_q] -> [batch, seqlen_q, nheads, 1]
|
||||
exp1_broad = exp1.transpose(1, 2).unsqueeze(-1)
|
||||
exp2_broad = exp2.transpose(1, 2).unsqueeze(-1)
|
||||
|
||||
# Merge outputs
|
||||
sum_exp = exp1_broad + exp2_broad
|
||||
o_merged = (o1 * exp1_broad + o2 * exp2_broad) / sum_exp
|
||||
|
||||
# Compute merged LSE
|
||||
lse_merged = max_lse + torch.log(exp1 + exp2)
|
||||
|
||||
# Ensure output has same dtype as input
|
||||
o_merged = o_merged.to(o1.dtype)
|
||||
|
||||
return o_merged, lse_merged
|
||||
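# Illustrative sanity sketch (assumes CUDA and bf16; analogous to _test_chunked_attention below):
# for the non-causal case, merging per-chunk outputs must reproduce full attention over the
# concatenated KV.
def _sketch_merge_identity():
    q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(1, 512, 8, 64, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(1, 512, 8, 64, device="cuda", dtype=torch.bfloat16)
    o1, lse1 = flash_attn_with_lse(q, k[:, :256], v[:, :256], causal=False)
    o2, lse2 = flash_attn_with_lse(q, k[:, 256:], v[:, 256:], causal=False)
    o_merged, _ = merge_attention_outputs(o1, lse1, o2, lse2)
    o_full, _ = flash_attn_with_lse(q, k, v, causal=False)
    max_diff = (o_full - o_merged).abs().max().item()
    assert max_diff < 2e-2, max_diff  # agreement up to bf16 rounding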
|
||||
|
||||
def chunked_attention_varlen(
|
||||
q: torch.Tensor,
|
||||
kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
cu_seqlens_k_list: List[torch.Tensor],
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k_list: List[int],
|
||||
softmax_scale: Optional[float] = None,
|
||||
causal_mask_per_chunk: Optional[List[bool]] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute attention with KV split across multiple chunks.
|
||||
|
||||
This is the core function for chunked prefill. It computes attention
|
||||
against each KV chunk and merges results using online softmax.
|
||||
|
||||
For causal attention with chunked KV:
|
||||
- First chunk (current tokens): Apply causal mask
|
||||
- Previous chunks: No causal mask (all previous tokens are valid context)
|
||||
|
||||
Args:
|
||||
q: Query tensor [total_q_tokens, nheads, headdim]
|
||||
kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim]
|
||||
cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
|
||||
cu_seqlens_k_list: List of cumulative sequence lengths for each KV chunk
|
||||
max_seqlen_q: Maximum query sequence length
|
||||
max_seqlen_k_list: List of maximum key sequence lengths for each chunk
|
||||
softmax_scale: Scaling factor
|
||||
causal_mask_per_chunk: Whether to apply causal mask for each chunk
|
||||
|
||||
Returns:
|
||||
out: Output tensor [total_q_tokens, nheads, headdim]
|
||||
"""
|
||||
if len(kv_chunks) == 0:
|
||||
raise ValueError("Need at least one KV chunk")
|
||||
|
||||
nheads = q.shape[1]
|
||||
headdim = q.shape[2]
|
||||
batch = cu_seqlens_q.shape[0] - 1
|
||||
|
||||
if softmax_scale is None:
|
||||
softmax_scale = 1.0 / math.sqrt(headdim)
|
||||
|
||||
if causal_mask_per_chunk is None:
|
||||
# Default: causal for last chunk only
|
||||
causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True]
|
||||
|
||||
# Initialize accumulated output and LSE
|
||||
accumulated_o = None
|
||||
accumulated_lse = None
|
||||
|
||||
for chunk_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
|
||||
is_causal = causal_mask_per_chunk[chunk_idx]
|
||||
|
||||
# Reshape Q for batch processing
|
||||
# For varlen, we need to handle each sequence separately
|
||||
# For simplicity, assume single sequence (batch=1) for now
|
||||
q_batched = q.unsqueeze(0) # [1, total_q, nheads, headdim]
|
||||
|
||||
# Compute attention for this chunk
|
||||
chunk_o, chunk_lse = flash_attn_with_lse(
|
||||
q_batched,
|
||||
k_chunk,
|
||||
v_chunk,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=is_causal,
|
||||
)
|
||||
|
||||
# Merge with accumulated
|
||||
if accumulated_o is None:
|
||||
accumulated_o = chunk_o
|
||||
accumulated_lse = chunk_lse
|
||||
else:
|
||||
accumulated_o, accumulated_lse = merge_attention_outputs(
|
||||
accumulated_o, accumulated_lse,
|
||||
chunk_o, chunk_lse,
|
||||
)
|
||||
|
||||
# Remove batch dimension
|
||||
return accumulated_o.squeeze(0)
|
||||
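# Usage sketch for a single sequence mid-prefill (shapes assumed, tensors arbitrary): the current
# chunk's queries attend causally to their own KV and without a mask to previously offloaded KV.
def _sketch_chunked_varlen_usage():
    q_cur = torch.randn(256, 8, 64, device="cuda", dtype=torch.bfloat16)      # [q_tokens, heads, dim]
    prev_k = torch.randn(1, 512, 8, 64, device="cuda", dtype=torch.bfloat16)  # loaded back from CPU
    prev_v = torch.randn(1, 512, 8, 64, device="cuda", dtype=torch.bfloat16)
    cur_k = torch.randn(1, 256, 8, 64, device="cuda", dtype=torch.bfloat16)   # current chunk's KV
    cur_v = torch.randn(1, 256, 8, 64, device="cuda", dtype=torch.bfloat16)
    out = chunked_attention_varlen(
        q_cur,
        kv_chunks=[(prev_k, prev_v), (cur_k, cur_v)],
        cu_seqlens_q=torch.tensor([0, 256], dtype=torch.int32, device="cuda"),
        cu_seqlens_k_list=[torch.tensor([0, 512], dtype=torch.int32, device="cuda"),
                           torch.tensor([0, 256], dtype=torch.int32, device="cuda")],
        max_seqlen_q=256,
        max_seqlen_k_list=[512, 256],
        causal_mask_per_chunk=[False, True],  # previous chunk unmasked, current chunk causal
    )
    return out  # [256, 8, 64]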
|
||||
|
||||
class ChunkedPrefillState:
|
||||
"""
|
||||
State for tracking chunked prefill progress.
|
||||
|
||||
This class maintains the accumulated attention output and LSE
|
||||
across multiple prefill chunks.
|
||||
"""
|
||||
|
||||
def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device):
|
||||
self.num_layers = num_layers
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
# Per-layer accumulated outputs
|
||||
# Each entry: (accumulated_output, accumulated_lse) or None
|
||||
self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
|
||||
None for _ in range(num_layers)
|
||||
]
|
||||
|
||||
# Track which chunks have been processed
|
||||
self.processed_chunks: int = 0
|
||||
|
||||
def update_layer(
|
||||
self,
|
||||
layer_id: int,
|
||||
chunk_output: torch.Tensor,
|
||||
chunk_lse: torch.Tensor,
|
||||
):
|
||||
"""Update accumulated state for a layer with a new chunk's output."""
|
||||
if self.layer_states[layer_id] is None:
|
||||
self.layer_states[layer_id] = (chunk_output, chunk_lse)
|
||||
else:
|
||||
acc_o, acc_lse = self.layer_states[layer_id]
|
||||
merged_o, merged_lse = merge_attention_outputs(
|
||||
acc_o, acc_lse,
|
||||
chunk_output, chunk_lse,
|
||||
)
|
||||
self.layer_states[layer_id] = (merged_o, merged_lse)
|
||||
|
||||
def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]:
|
||||
"""Get the final accumulated output for a layer."""
|
||||
if self.layer_states[layer_id] is None:
|
||||
return None
|
||||
return self.layer_states[layer_id][0]
|
||||
|
||||
def clear(self):
|
||||
"""Clear all accumulated state."""
|
||||
self.layer_states = [None for _ in range(self.num_layers)]
|
||||
self.processed_chunks = 0
|
||||
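# Sketch of per-layer accumulation across two chunks (placeholder tensors; the pattern a hybrid
# manager could use when KV chunks arrive one at a time):
def _sketch_state_accumulation(q, k0, v0, k1, v1):
    state = ChunkedPrefillState(num_layers=1, dtype=torch.bfloat16, device=torch.device("cuda"))
    o0, lse0 = flash_attn_with_lse(q, k0, v0, causal=False)
    state.update_layer(0, o0, lse0)
    o1, lse1 = flash_attn_with_lse(q, k1, v1, causal=True)
    state.update_layer(0, o1, lse1)
    merged = state.get_layer_output(0)  # attention of q over [k0 | k1] with the causal split above
    state.clear()
    return merged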
|
||||
|
||||
# Test function
|
||||
def _test_chunked_attention():
|
||||
"""Test chunked attention correctness against full attention."""
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
torch.manual_seed(42)
|
||||
|
||||
batch, seqlen, nheads, headdim = 1, 1024, 32, 128
|
||||
|
||||
# Generate random Q, K, V
|
||||
q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
|
||||
k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
|
||||
v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
# Full attention (reference)
|
||||
out_ref = flash_attn_func(q, k, v, causal=True)
|
||||
|
||||
# Chunked attention
|
||||
chunk_size = 256
|
||||
num_chunks = seqlen // chunk_size
|
||||
|
||||
accumulated_o = None
|
||||
accumulated_lse = None
|
||||
|
||||
for i in range(num_chunks):
|
||||
start = i * chunk_size
|
||||
end = (i + 1) * chunk_size
|
||||
|
||||
# Q for this chunk
|
||||
q_chunk = q[:, start:end, :, :]
|
||||
|
||||
# K, V up to current position (for causal)
|
||||
k_context = k[:, :end, :, :]
|
||||
v_context = v[:, :end, :, :]
|
||||
|
||||
# Compute attention
|
||||
chunk_o, chunk_lse = flash_attn_with_lse(
|
||||
q_chunk, k_context, v_context, causal=True
|
||||
)
|
||||
|
||||
if accumulated_o is None:
|
||||
accumulated_o = chunk_o
|
||||
accumulated_lse = chunk_lse
|
||||
else:
|
||||
# For chunked prefill, we need to concatenate outputs, not merge
|
||||
# Because each chunk's Q attends to different K positions
|
||||
accumulated_o = torch.cat([accumulated_o, chunk_o], dim=1)
|
||||
|
||||
# Compare
|
||||
max_diff = (out_ref - accumulated_o).abs().max().item()
|
||||
print(f"Max difference: {max_diff}")
|
||||
assert max_diff < 1e-2, f"Chunked attention differs from reference: {max_diff}"
|
||||
print("Test passed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_test_chunked_attention()
|
||||
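A minimal sketch of how ChunkedPrefillState is meant to be driven when the same query block attends to successive KV chunks, assuming the flash_attn_with_lse and merge_attention_outputs helpers defined earlier in this file. The driver function name attend_in_kv_chunks is hypothetical, and causal masking is omitted for simplicity; this is the merge path, as opposed to the concatenation path exercised by the test above (where each chunk has its own queries).

def attend_in_kv_chunks(q, k, v, state: ChunkedPrefillState, layer_id: int, chunk_size: int = 256):
    # q: [1, q_len, heads, head_dim]; k, v: [1, kv_len, heads, head_dim]
    for start in range(0, k.shape[1], chunk_size):
        end = min(start + chunk_size, k.shape[1])
        # Same queries, disjoint KV chunk: partial outputs are LSE-merged.
        o, lse = flash_attn_with_lse(q, k[:, start:end], v[:, start:end], causal=False)
        state.update_layer(layer_id, o, lse)
    return state.get_layer_output(layer_id)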
262
nanovllm/kvcache/gpu_manager.py
Normal file
262
nanovllm/kvcache/gpu_manager.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
GPU-only KV cache manager.
|
||||
|
||||
This is the default manager when CPU offload is disabled.
|
||||
Refactored from the original block_manager.py to implement
|
||||
the KVCacheManager interface.
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from typing import List, Tuple, Dict, Optional
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
from nanovllm.kvcache.base_manager import KVCacheManager
|
||||
|
||||
|
||||
class Block:
|
||||
"""Physical block in GPU memory."""
|
||||
|
||||
def __init__(self, block_id: int):
|
||||
self.block_id = block_id
|
||||
self.ref_count = 0
|
||||
self.hash = -1
|
||||
self.token_ids: List[int] = []
|
||||
|
||||
def update(self, hash: int, token_ids: List[int]):
|
||||
self.hash = hash
|
||||
self.token_ids = token_ids
|
||||
|
||||
def reset(self):
|
||||
self.ref_count = 1
|
||||
self.hash = -1
|
||||
self.token_ids = []
|
||||
|
||||
|
||||
class GPUOnlyManager(KVCacheManager):
|
||||
"""
|
||||
Pure GPU KV cache manager.
|
||||
|
||||
This is the default implementation when enable_cpu_offload=False.
|
||||
All KV cache resides in GPU memory.
|
||||
|
||||
Features:
|
||||
- Paged attention with configurable block size
|
||||
- Prefix caching via xxhash
|
||||
- Reference counting for block sharing
|
||||
|
||||
This manager is fully compatible with CUDA graphs since
|
||||
all data stays on GPU at fixed addresses.
|
||||
"""
|
||||
|
||||
def __init__(self, num_blocks: int, block_size: int):
|
||||
"""
|
||||
Initialize GPU-only manager.
|
||||
|
||||
Args:
|
||||
num_blocks: Total number of blocks to manage
|
||||
block_size: Tokens per block (default 256)
|
||||
"""
|
||||
self._block_size = block_size
|
||||
self._num_blocks = num_blocks
|
||||
|
||||
# Block metadata
|
||||
self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]
|
||||
|
||||
# Prefix cache: hash -> block_id
|
||||
self.hash_to_block_id: Dict[int, int] = {}
|
||||
|
||||
# Free/used tracking
|
||||
self.free_block_ids: deque[int] = deque(range(num_blocks))
|
||||
self.used_block_ids: set[int] = set()
|
||||
|
||||
# KV cache tensors (set by allocate_cache)
|
||||
self.kv_cache: Optional[Tensor] = None
|
||||
self.num_layers: int = 0
|
||||
self.num_kv_heads: int = 0
|
||||
self.head_dim: int = 0
|
||||
|
||||
@property
|
||||
def block_size(self) -> int:
|
||||
return self._block_size
|
||||
|
||||
@property
|
||||
def num_free_blocks(self) -> int:
|
||||
return len(self.free_block_ids)
|
||||
|
||||
def allocate_cache(
|
||||
self,
|
||||
num_layers: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
"""Allocate GPU KV cache tensor."""
|
||||
self.num_layers = num_layers
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
|
||||
# Shape: [2, num_layers, num_blocks, block_size, kv_heads, head_dim]
|
||||
# 2 for K and V
|
||||
self.kv_cache = torch.empty(
|
||||
2, num_layers, self._num_blocks, self._block_size,
|
||||
num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
"""Get K/V cache for a layer."""
|
||||
assert self.kv_cache is not None, "Cache not allocated"
|
||||
return self.kv_cache[0, layer_id], self.kv_cache[1, layer_id]
|
||||
|
||||
def _allocate_block(self, block_id: int) -> Block:
|
||||
"""Internal: allocate a specific block."""
|
||||
block = self.blocks[block_id]
|
||||
assert block.ref_count == 0, f"Block {block_id} is not free"
|
||||
block.reset()
|
||||
self.free_block_ids.remove(block_id)
|
||||
self.used_block_ids.add(block_id)
|
||||
return block
|
||||
|
||||
def _deallocate_block(self, block_id: int) -> None:
|
||||
"""Internal: deallocate a block."""
|
||||
assert self.blocks[block_id].ref_count == 0
|
||||
self.used_block_ids.remove(block_id)
|
||||
self.free_block_ids.append(block_id)
|
||||
|
||||
def can_allocate(self, seq: Sequence) -> bool:
|
||||
"""Check if we have enough blocks for the sequence."""
|
||||
return len(self.free_block_ids) >= seq.num_blocks
|
||||
|
||||
def allocate(self, seq: Sequence) -> None:
|
||||
"""
|
||||
Allocate blocks for a sequence during prefill.
|
||||
|
||||
Implements prefix caching: if a block's content matches
|
||||
a previously cached block, reuse it instead of allocating new.
|
||||
"""
|
||||
assert not seq.block_table, "Sequence already has blocks allocated"
|
||||
|
||||
h = -1 # Hash chain
|
||||
cache_miss = False
|
||||
|
||||
for i in range(seq.num_blocks):
|
||||
token_ids = seq.block(i)
|
||||
|
||||
# Only compute hash for full blocks
|
||||
if len(token_ids) == self._block_size:
|
||||
h = self.compute_hash(token_ids, h)
|
||||
else:
|
||||
h = -1
|
||||
|
||||
# Try prefix cache lookup
|
||||
block_id = self.hash_to_block_id.get(h, -1)
|
||||
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
|
||||
cache_miss = True
|
||||
|
||||
if cache_miss:
|
||||
# Cache miss: allocate new block
|
||||
block_id = self.free_block_ids[0]
|
||||
block = self._allocate_block(block_id)
|
||||
else:
|
||||
# Cache hit: reuse existing block
|
||||
seq.num_cached_tokens += self._block_size
|
||||
if block_id in self.used_block_ids:
|
||||
# Block is in use, increment ref count
|
||||
block = self.blocks[block_id]
|
||||
block.ref_count += 1
|
||||
else:
|
||||
# Block was freed but hash still valid
|
||||
block = self._allocate_block(block_id)
|
||||
|
||||
# Update hash mapping for full blocks
|
||||
if h != -1:
|
||||
block.update(h, token_ids)
|
||||
self.hash_to_block_id[h] = block_id
|
||||
|
||||
seq.block_table.append(block_id)
|
||||
|
||||
def deallocate(self, seq: Sequence) -> None:
|
||||
"""Release all blocks for a sequence."""
|
||||
for block_id in reversed(seq.block_table):
|
||||
block = self.blocks[block_id]
|
||||
block.ref_count -= 1
|
||||
if block.ref_count == 0:
|
||||
self._deallocate_block(block_id)
|
||||
|
||||
seq.num_cached_tokens = 0
|
||||
seq.block_table.clear()
|
||||
|
||||
def can_append(self, seq: Sequence) -> bool:
|
||||
"""Check if we can append a token (may need new block)."""
|
||||
# Need new block only if current position is at block boundary
|
||||
need_new_block = (len(seq) % self._block_size == 1)
|
||||
return len(self.free_block_ids) >= int(need_new_block)
|
||||
|
||||
def may_append(self, seq: Sequence) -> None:
|
||||
"""Handle potential new block allocation during decode."""
|
||||
block_table = seq.block_table
|
||||
last_block = self.blocks[block_table[-1]]
|
||||
|
||||
seq_len = len(seq)
|
||||
pos_in_block = seq_len % self._block_size
|
||||
|
||||
if pos_in_block == 1:
|
||||
# Just crossed into new block, need to allocate
|
||||
assert last_block.hash != -1, "Previous block should be complete"
|
||||
block_id = self.free_block_ids[0]
|
||||
self._allocate_block(block_id)
|
||||
block_table.append(block_id)
|
||||
|
||||
elif pos_in_block == 0:
|
||||
# Just filled a block, compute hash for prefix cache
|
||||
assert last_block.hash == -1, "Block should not have hash yet"
|
||||
token_ids = seq.block(seq.num_blocks - 1)
|
||||
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
|
||||
h = self.compute_hash(token_ids, prefix)
|
||||
last_block.update(h, token_ids)
|
||||
self.hash_to_block_id[h] = last_block.block_id
|
||||
|
||||
else:
|
||||
# Middle of block, nothing to do
|
||||
assert last_block.hash == -1
|
||||
|
||||
def prepare_for_attention(
|
||||
self,
|
||||
seqs: List[Sequence],
|
||||
is_prefill: bool,
|
||||
) -> None:
|
||||
"""
|
||||
No-op for GPU-only manager.
|
||||
|
||||
All blocks are already on GPU, no preparation needed.
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_gpu_block_tables(
|
||||
self,
|
||||
seqs: List[Sequence],
|
||||
) -> List[List[int]]:
|
||||
"""
|
||||
Return block tables directly (logical = physical for GPU-only).
|
||||
"""
|
||||
return [list(seq.block_table) for seq in seqs]
|
||||
|
||||
def post_attention_cleanup(
|
||||
self,
|
||||
seqs: List[Sequence],
|
||||
is_prefill: bool,
|
||||
) -> None:
|
||||
"""No-op for GPU-only manager."""
|
||||
pass
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"GPUOnlyManager("
|
||||
f"num_blocks={self._num_blocks}, "
|
||||
f"block_size={self._block_size}, "
|
||||
f"free={len(self.free_block_ids)}, "
|
||||
f"used={len(self.used_block_ids)}"
|
||||
f")"
|
||||
)
|
||||
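A minimal usage sketch of GPUOnlyManager under the interface above, assuming a Sequence seq built elsewhere and that compute_hash is provided by the KVCacheManager base class, as the calls above suggest.

manager = GPUOnlyManager(num_blocks=1024, block_size=256)
manager.allocate_cache(num_layers=32, num_kv_heads=8, head_dim=128, dtype=torch.bfloat16)

if manager.can_allocate(seq):
    manager.allocate(seq)          # fills seq.block_table, reusing hash-chained prefix blocks
tables = manager.get_gpu_block_tables([seq])   # logical == physical for the GPU-only manager

# ... run prefill/decode, calling may_append(seq) after each generated token ...

manager.deallocate(seq)            # ref counts drop to zero; blocks return to the free list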
906
nanovllm/kvcache/hybrid_manager.py
Normal file
906
nanovllm/kvcache/hybrid_manager.py
Normal file
@@ -0,0 +1,906 @@
|
||||
"""
|
||||
Hybrid CPU-GPU KV cache manager with CUDA Graph support.
|
||||
|
||||
Key design for CUDA Graph compatibility:
|
||||
1. GPU buffer has fixed addresses (allocated once)
|
||||
2. CPU pool has fixed addresses (pinned memory)
|
||||
3. gather_indices tensor has fixed address, variable content
|
||||
4. H2D transfer uses gathered_copy kernel inside CUDA graphs
|
||||
5. Graph replay only needs index updates (tiny overhead)
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum, auto
|
||||
from typing import List, Tuple, Dict, Set, Optional
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
from nanovllm.kvcache.base_manager import KVCacheManager
|
||||
from nanovllm.kvcache.offload_engine import OffloadEngine
|
||||
from nanovllm.kvcache.policies.base_policy import EvictionPolicy
|
||||
from nanovllm.kvcache.policies.lru_policy import LRUPolicy
|
||||
|
||||
|
||||
class BlockLocation(Enum):
|
||||
"""Where a logical block's data currently resides."""
|
||||
GPU = auto()
|
||||
CPU = auto()
|
||||
INVALID = auto() # Not yet written / deallocated
|
||||
|
||||
|
||||
@dataclass
|
||||
class LogicalBlock:
|
||||
"""
|
||||
Logical block that can be mapped to GPU or CPU physical storage.
|
||||
|
||||
Sequences reference logical blocks. Physical blocks are the actual
|
||||
storage locations (GPU slots or CPU blocks).
|
||||
"""
|
||||
logical_id: int
|
||||
location: BlockLocation = BlockLocation.INVALID
|
||||
gpu_slot: int = -1 # GPU buffer slot ID (if on GPU)
|
||||
cpu_block_id: int = -1 # CPU pool block ID (if on CPU)
|
||||
ref_count: int = 0
|
||||
hash: int = -1
|
||||
token_ids: List[int] = field(default_factory=list)
|
||||
|
||||
def reset(self):
|
||||
self.location = BlockLocation.INVALID
|
||||
self.gpu_slot = -1
|
||||
self.cpu_block_id = -1
|
||||
self.ref_count = 0
|
||||
self.hash = -1
|
||||
self.token_ids = []
|
||||
|
||||
|
||||
class HybridKVCacheManager(KVCacheManager):
|
||||
"""
|
||||
Hybrid CPU-GPU KV cache manager with CUDA Graph support.
|
||||
|
||||
Architecture:
|
||||
- GPU buffer: Fixed-size working set (num_gpu_slots)
|
||||
- CPU pool: Overflow storage (num_cpu_blocks)
|
||||
- Logical blocks: What sequences reference (num_gpu_slots + num_cpu_blocks)
|
||||
|
||||
CUDA Graph compatibility:
|
||||
- All tensor addresses fixed at init time
|
||||
- prepare_for_attention() updates gather_indices (outside graph)
|
||||
- gathered_h2d_layer() executes transfer (inside graph)
|
||||
|
||||
Strategy:
|
||||
1. New KV data written to GPU slots
|
||||
2. Cold blocks evicted to CPU using configurable policy
|
||||
3. Needed blocks prefetched back to GPU before attention
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_gpu_slots: int,
|
||||
num_cpu_blocks: int,
|
||||
block_size: int,
|
||||
policy: Optional[EvictionPolicy] = None,
|
||||
):
|
||||
"""
|
||||
Initialize hybrid manager.
|
||||
|
||||
Args:
|
||||
num_gpu_slots: Number of GPU buffer slots (working set)
|
||||
num_cpu_blocks: Number of CPU pool blocks (overflow)
|
||||
block_size: Tokens per block
|
||||
policy: Eviction policy (default: LRU)
|
||||
"""
|
||||
self._block_size = block_size
|
||||
self.num_gpu_slots = num_gpu_slots
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
self.total_blocks = num_gpu_slots + num_cpu_blocks
|
||||
|
||||
# Eviction policy
|
||||
self.policy = policy or LRUPolicy()
|
||||
|
||||
# Logical blocks (what sequences reference)
|
||||
self.logical_blocks: List[LogicalBlock] = [
|
||||
LogicalBlock(i) for i in range(self.total_blocks)
|
||||
]
|
||||
self.free_logical_ids: deque[int] = deque(range(self.total_blocks))
|
||||
|
||||
# GPU slot management (slots are fixed, mapping is variable)
|
||||
self.free_gpu_slots: deque[int] = deque(range(num_gpu_slots))
|
||||
self.gpu_slot_to_logical: Dict[int, int] = {} # gpu_slot -> logical_id
|
||||
|
||||
# CPU block management
|
||||
self.free_cpu_blocks: deque[int] = deque(range(num_cpu_blocks))
|
||||
self.cpu_block_to_logical: Dict[int, int] = {} # cpu_block -> logical_id
|
||||
|
||||
# Prefix cache (uses logical block IDs)
|
||||
self.hash_to_logical_id: Dict[int, int] = {}
|
||||
|
||||
# Step counter for policy
|
||||
self.current_step = 0
|
||||
|
||||
# Offload engine (set by allocate_cache)
|
||||
self.offload_engine: Optional[OffloadEngine] = None
|
||||
|
||||
# Track blocks pending GPU load (for decode graph)
|
||||
self.pending_gpu_loads: Set[int] = set() # logical_ids
|
||||
|
||||
# Track blocks that have been prefilled (KV written) for chunked prefill
|
||||
self.prefilled_blocks: Set[int] = set() # logical_ids
|
||||
|
||||
@property
|
||||
def block_size(self) -> int:
|
||||
return self._block_size
|
||||
|
||||
@property
|
||||
def num_free_blocks(self) -> int:
|
||||
return len(self.free_logical_ids)
|
||||
|
||||
def allocate_cache(
|
||||
self,
|
||||
num_layers: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
"""Initialize the offload engine with actual cache storage."""
|
||||
self.offload_engine = OffloadEngine(
|
||||
num_layers=num_layers,
|
||||
num_gpu_blocks=self.num_gpu_slots,
|
||||
num_cpu_blocks=self.num_cpu_blocks,
|
||||
block_size=self._block_size,
|
||||
num_kv_heads=num_kv_heads,
|
||||
head_dim=head_dim,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||
"""Get GPU K/V cache tensors for a layer."""
|
||||
assert self.offload_engine is not None
|
||||
return self.offload_engine.get_layer_cache(layer_id)
|
||||
|
||||
def _allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> int:
|
||||
"""
|
||||
Get a free GPU slot, evicting if necessary.
|
||||
|
||||
Args:
|
||||
protected_logical_ids: Logical block IDs that cannot be evicted
|
||||
|
||||
Returns:
|
||||
GPU slot ID
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no GPU slot is available
|
||||
"""
|
||||
if self.free_gpu_slots:
|
||||
return self.free_gpu_slots.popleft()
|
||||
|
||||
# Need to evict - find victim using policy
|
||||
return self._evict_to_cpu(protected_logical_ids)
|
||||
|
||||
def _try_allocate_gpu_slot(self, protected_logical_ids: Optional[Set[int]] = None) -> Optional[int]:
|
||||
"""
|
||||
Try to get a free GPU slot, evicting if necessary.
|
||||
|
||||
Unlike _allocate_gpu_slot(), returns None instead of raising if no eviction possible.
|
||||
|
||||
Args:
|
||||
protected_logical_ids: Logical block IDs that cannot be evicted
|
||||
|
||||
Returns:
|
||||
GPU slot ID, or None if no slot available
|
||||
"""
|
||||
if self.free_gpu_slots:
|
||||
return self.free_gpu_slots.popleft()
|
||||
|
||||
# Check if we can evict
|
||||
protected = protected_logical_ids or set()
|
||||
for gpu_slot, logical_id in self.gpu_slot_to_logical.items():
|
||||
if logical_id not in protected:
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.ref_count > 0:
|
||||
# Found evictable block
|
||||
return self._evict_to_cpu(protected_logical_ids)
|
||||
|
||||
# No evictable blocks
|
||||
return None
|
||||
|
||||
def _evict_to_cpu(self, protected_logical_ids: Optional[Set[int]] = None) -> int:
|
||||
"""
|
||||
Evict a GPU block to CPU to make room.
|
||||
|
||||
Args:
|
||||
protected_logical_ids: Logical block IDs that cannot be evicted
|
||||
|
||||
Returns:
|
||||
The freed GPU slot ID
|
||||
"""
|
||||
protected = protected_logical_ids or set()
|
||||
|
||||
# Find candidates (blocks currently on GPU with ref_count > 0, excluding protected)
|
||||
candidates: Set[int] = set()
|
||||
for gpu_slot, logical_id in self.gpu_slot_to_logical.items():
|
||||
if logical_id in protected:
|
||||
continue # Skip protected blocks
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.ref_count > 0: # Only evict blocks still in use
|
||||
candidates.add(gpu_slot)
|
||||
|
||||
if not candidates:
|
||||
raise RuntimeError(
|
||||
f"No GPU slots available for eviction. "
|
||||
f"GPU slots: {self.num_gpu_slots}, protected: {len(protected)}, "
|
||||
f"need more GPU memory or reduce sequence length"
|
||||
)
|
||||
|
||||
# Use policy to select victim
|
||||
victim_gpu_slot = self.policy.select_victim(candidates)
|
||||
logical_id = self.gpu_slot_to_logical[victim_gpu_slot]
|
||||
block = self.logical_blocks[logical_id]
|
||||
|
||||
# Allocate CPU block
|
||||
if not self.free_cpu_blocks:
|
||||
raise RuntimeError("Both GPU and CPU are full")
|
||||
cpu_block_id = self.free_cpu_blocks.popleft()
|
||||
|
||||
# Async offload GPU -> CPU
|
||||
self.offload_engine.offload_block_async(
|
||||
layer_id=0, # TODO: handle per-layer offloading
|
||||
gpu_block_id=victim_gpu_slot,
|
||||
cpu_block_id=cpu_block_id,
|
||||
)
|
||||
|
||||
# Update mappings
|
||||
del self.gpu_slot_to_logical[victim_gpu_slot]
|
||||
self.cpu_block_to_logical[cpu_block_id] = logical_id
|
||||
|
||||
block.location = BlockLocation.CPU
|
||||
block.gpu_slot = -1
|
||||
block.cpu_block_id = cpu_block_id
|
||||
|
||||
# Notify policy
|
||||
self.policy.on_block_evicted(victim_gpu_slot)
|
||||
|
||||
return victim_gpu_slot
|
||||
|
||||
def _ensure_on_gpu(
|
||||
self,
|
||||
logical_id: int,
|
||||
protected_logical_ids: Optional[Set[int]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Ensure a logical block is on GPU.
|
||||
|
||||
Args:
|
||||
logical_id: Logical block ID
|
||||
protected_logical_ids: Logical block IDs that cannot be evicted
|
||||
|
||||
Returns:
|
||||
GPU slot ID where the block is/will be
|
||||
"""
|
||||
block = self.logical_blocks[logical_id]
|
||||
|
||||
if block.location == BlockLocation.GPU:
|
||||
# Already on GPU, update policy
|
||||
self.policy.on_block_access(block.gpu_slot, self.current_step)
|
||||
return block.gpu_slot
|
||||
|
||||
if block.location == BlockLocation.CPU:
|
||||
# Need to prefetch from CPU
|
||||
gpu_slot = self._allocate_gpu_slot(protected_logical_ids)
|
||||
|
||||
# Async prefetch CPU -> GPU
|
||||
self.offload_engine.prefetch_block_async(
|
||||
layer_id=0, # TODO: handle per-layer
|
||||
cpu_block_id=block.cpu_block_id,
|
||||
gpu_block_id=gpu_slot,
|
||||
)
|
||||
|
||||
# Update mappings
|
||||
self.free_cpu_blocks.append(block.cpu_block_id)
|
||||
del self.cpu_block_to_logical[block.cpu_block_id]
|
||||
|
||||
self.gpu_slot_to_logical[gpu_slot] = logical_id
|
||||
|
||||
block.location = BlockLocation.GPU
|
||||
block.gpu_slot = gpu_slot
|
||||
block.cpu_block_id = -1
|
||||
|
||||
# Notify policy
|
||||
self.policy.on_block_prefetched(gpu_slot, self.current_step)
|
||||
|
||||
return gpu_slot
|
||||
|
||||
raise RuntimeError(f"Block {logical_id} is in invalid state")
|
||||
|
||||
def can_allocate(self, seq: Sequence) -> bool:
|
||||
"""Check if we can allocate blocks for a new sequence."""
|
||||
return len(self.free_logical_ids) >= seq.num_blocks
|
||||
|
||||
def allocate(self, seq: Sequence) -> None:
|
||||
"""
|
||||
Allocate logical blocks for prefill.
|
||||
|
||||
New blocks are allocated on GPU when possible. If GPU is full and all
|
||||
GPU blocks belong to this sequence (can't evict), remaining blocks
|
||||
are allocated to CPU for chunked prefill.
|
||||
"""
|
||||
assert not seq.block_table, "Sequence already has blocks"
|
||||
|
||||
h = -1
|
||||
cache_miss = False
|
||||
|
||||
# Track blocks allocated for this sequence to protect them from eviction
|
||||
allocated_for_seq: Set[int] = set()
|
||||
|
||||
for i in range(seq.num_blocks):
|
||||
token_ids = seq.block(i)
|
||||
|
||||
# Hash for full blocks only
|
||||
if len(token_ids) == self._block_size:
|
||||
h = self.compute_hash(token_ids, h)
|
||||
else:
|
||||
h = -1
|
||||
|
||||
# Check prefix cache
|
||||
cached_logical_id = self.hash_to_logical_id.get(h, -1)
|
||||
if cached_logical_id != -1:
|
||||
cached_block = self.logical_blocks[cached_logical_id]
|
||||
if cached_block.token_ids == token_ids and cached_block.ref_count > 0:
|
||||
# Cache hit
|
||||
cached_block.ref_count += 1
|
||||
seq.num_cached_tokens += self._block_size
|
||||
seq.block_table.append(cached_logical_id)
|
||||
allocated_for_seq.add(cached_logical_id)
|
||||
|
||||
# Ensure block is on GPU (protect already allocated blocks)
|
||||
if cached_block.location == BlockLocation.CPU:
|
||||
self._ensure_on_gpu(cached_logical_id, allocated_for_seq)
|
||||
|
||||
continue
|
||||
|
||||
cache_miss = True
|
||||
|
||||
# Allocate new logical block
|
||||
logical_id = self.free_logical_ids.popleft()
|
||||
block = self.logical_blocks[logical_id]
|
||||
block.ref_count = 1
|
||||
block.hash = h
|
||||
block.token_ids = token_ids.copy() if len(token_ids) == self._block_size else []
|
||||
|
||||
# Try to allocate GPU slot
|
||||
gpu_slot = self._try_allocate_gpu_slot(allocated_for_seq)
|
||||
if gpu_slot is not None:
|
||||
# Got GPU slot
|
||||
block.location = BlockLocation.GPU
|
||||
block.gpu_slot = gpu_slot
|
||||
block.cpu_block_id = -1
|
||||
self.gpu_slot_to_logical[gpu_slot] = logical_id
|
||||
else:
|
||||
# GPU full and can't evict (all protected) - allocate to CPU
|
||||
# This block will be written via chunked prefill
|
||||
if not self.free_cpu_blocks:
|
||||
raise RuntimeError(
|
||||
f"Both GPU and CPU are full. Need {seq.num_blocks} blocks, "
|
||||
f"GPU has {self.num_gpu_slots}, CPU has {self.num_cpu_blocks}"
|
||||
)
|
||||
cpu_block_id = self.free_cpu_blocks.popleft()
|
||||
block.location = BlockLocation.CPU
|
||||
block.gpu_slot = -1
|
||||
block.cpu_block_id = cpu_block_id
|
||||
self.cpu_block_to_logical[cpu_block_id] = logical_id
|
||||
|
||||
allocated_for_seq.add(logical_id)
|
||||
|
||||
# Update prefix cache
|
||||
if h != -1:
|
||||
self.hash_to_logical_id[h] = logical_id
|
||||
|
||||
            # Notify policy (only when the block actually received a GPU slot;
            # CPU-resident blocks are not tracked by the eviction policy)
            if gpu_slot is not None:
                self.policy.on_block_allocated(gpu_slot, self.current_step)
|
||||
|
||||
seq.block_table.append(logical_id)
|
||||
|
||||
def deallocate(self, seq: Sequence) -> None:
|
||||
"""Release all blocks for a sequence."""
|
||||
for logical_id in reversed(seq.block_table):
|
||||
block = self.logical_blocks[logical_id]
|
||||
block.ref_count -= 1
|
||||
|
||||
if block.ref_count == 0:
|
||||
# Free physical block
|
||||
if block.location == BlockLocation.GPU:
|
||||
self.free_gpu_slots.append(block.gpu_slot)
|
||||
del self.gpu_slot_to_logical[block.gpu_slot]
|
||||
self.policy.on_block_deallocated(block.gpu_slot)
|
||||
elif block.location == BlockLocation.CPU:
|
||||
self.free_cpu_blocks.append(block.cpu_block_id)
|
||||
del self.cpu_block_to_logical[block.cpu_block_id]
|
||||
|
||||
# Free logical block
|
||||
block.reset()
|
||||
self.free_logical_ids.append(logical_id)
|
||||
|
||||
# Remove from prefilled tracking
|
||||
self.prefilled_blocks.discard(logical_id)
|
||||
|
||||
seq.num_cached_tokens = 0
|
||||
seq.block_table.clear()
|
||||
|
||||
def can_append(self, seq: Sequence) -> bool:
|
||||
"""Check if we can append a token."""
|
||||
need_new_block = (len(seq) % self._block_size == 1)
|
||||
return len(self.free_logical_ids) >= int(need_new_block)
|
||||
|
||||
def may_append(self, seq: Sequence) -> None:
|
||||
"""Handle potential new block allocation during decode."""
|
||||
block_table = seq.block_table
|
||||
last_logical_id = block_table[-1]
|
||||
last_block = self.logical_blocks[last_logical_id]
|
||||
|
||||
seq_len = len(seq)
|
||||
pos_in_block = seq_len % self._block_size
|
||||
|
||||
if pos_in_block == 1:
|
||||
# Need new block
|
||||
assert last_block.hash != -1
|
||||
|
||||
logical_id = self.free_logical_ids.popleft()
|
||||
block = self.logical_blocks[logical_id]
|
||||
block.ref_count = 1
|
||||
block.hash = -1
|
||||
block.token_ids = []
|
||||
|
||||
# New decode blocks go to GPU
|
||||
gpu_slot = self._allocate_gpu_slot()
|
||||
block.location = BlockLocation.GPU
|
||||
block.gpu_slot = gpu_slot
|
||||
|
||||
self.gpu_slot_to_logical[gpu_slot] = logical_id
|
||||
self.policy.on_block_allocated(gpu_slot, self.current_step)
|
||||
|
||||
block_table.append(logical_id)
|
||||
|
||||
elif pos_in_block == 0:
|
||||
# Block is full, update hash for prefix cache
|
||||
assert last_block.hash == -1
|
||||
token_ids = seq.block(seq.num_blocks - 1)
|
||||
prefix_hash = (
|
||||
self.logical_blocks[block_table[-2]].hash
|
||||
if len(block_table) > 1 else -1
|
||||
)
|
||||
h = self.compute_hash(token_ids, prefix_hash)
|
||||
last_block.hash = h
|
||||
last_block.token_ids = token_ids.copy()
|
||||
self.hash_to_logical_id[h] = last_logical_id
|
||||
|
||||
def prepare_for_attention(
|
||||
self,
|
||||
seqs: List[Sequence],
|
||||
is_prefill: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Prepare KV cache for attention computation.
|
||||
|
||||
For prefill: async prefetch blocks from CPU to GPU.
|
||||
For decode: update gather_indices for CUDA graph.
|
||||
"""
|
||||
self.current_step += 1
|
||||
|
||||
# Collect all needed logical blocks
|
||||
needed_logical_ids: Set[int] = set()
|
||||
for seq in seqs:
|
||||
needed_logical_ids.update(seq.block_table)
|
||||
|
||||
if is_prefill:
|
||||
# Prefill: ensure all blocks on GPU (async prefetch)
|
||||
# Pass needed_logical_ids as protected to prevent evicting blocks we need
|
||||
for logical_id in needed_logical_ids:
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.location == BlockLocation.CPU:
|
||||
self._ensure_on_gpu(logical_id, needed_logical_ids)
|
||||
|
||||
# Wait for all prefetches to complete
|
||||
self.offload_engine.wait_all_transfers()
|
||||
|
||||
else:
|
||||
# Decode: Check if we need chunked decode
|
||||
cpu_blocks_count = sum(
|
||||
1 for lid in needed_logical_ids
|
||||
if self.logical_blocks[lid].location == BlockLocation.CPU
|
||||
)
|
||||
|
||||
if cpu_blocks_count > self.num_gpu_slots:
|
||||
# Too many blocks on CPU - will use chunked decode
|
||||
# Don't try to load all blocks now
|
||||
return
|
||||
|
||||
# Standard decode: prepare gather_indices for CUDA graph
|
||||
# Identify blocks needing transfer
|
||||
self.pending_gpu_loads.clear()
|
||||
mappings_per_layer: List[List[Tuple[int, int]]] = [
|
||||
[] for _ in range(self.offload_engine.num_layers)
|
||||
]
|
||||
|
||||
for logical_id in needed_logical_ids:
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.location == BlockLocation.CPU:
|
||||
# Allocate GPU slot (protect needed blocks from eviction)
|
||||
gpu_slot = self._allocate_gpu_slot(needed_logical_ids)
|
||||
|
||||
# Record mapping for each layer
|
||||
for layer_id in range(self.offload_engine.num_layers):
|
||||
mappings_per_layer[layer_id].append(
|
||||
(block.cpu_block_id, gpu_slot)
|
||||
)
|
||||
|
||||
# Update block state
|
||||
self.free_cpu_blocks.append(block.cpu_block_id)
|
||||
del self.cpu_block_to_logical[block.cpu_block_id]
|
||||
|
||||
self.gpu_slot_to_logical[gpu_slot] = logical_id
|
||||
block.location = BlockLocation.GPU
|
||||
block.gpu_slot = gpu_slot
|
||||
block.cpu_block_id = -1
|
||||
|
||||
self.pending_gpu_loads.add(logical_id)
|
||||
self.policy.on_block_prefetched(gpu_slot, self.current_step)
|
||||
|
||||
elif block.location == BlockLocation.GPU:
|
||||
self.policy.on_block_access(block.gpu_slot, self.current_step)
|
||||
|
||||
# Update gather indices (outside graph)
|
||||
self.offload_engine.update_gather_indices_all_layers(mappings_per_layer)
|
||||
self.offload_engine.sync_indices()
|
||||
|
||||
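    # Decode-path sketch (comment only): the index update in prepare_for_attention
    # happens outside the CUDA graph, while the gathered H2D copy is what gets
    # captured and replayed, roughly:
    #
    #   manager.prepare_for_attention(seqs, is_prefill=False)   # outside graph
    #   graph = torch.cuda.CUDAGraph()
    #   with torch.cuda.graph(graph):                           # captured once
    #       manager.offload_engine.gathered_h2d_all_layers()
    #       # ... decode attention kernels ...
    #   graph.replay()                                          # later steps: refresh indices, then replay
    #
    # The fixed addresses of the GPU cache, the pinned CPU cache, and the
    # gather_indices tensors are what keep the replay valid; only the index
    # contents change between steps.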
def needs_chunked_decode(self, seq: Sequence) -> bool:
|
||||
"""
|
||||
Check if sequence needs chunked decode.
|
||||
|
||||
Returns True if there are blocks on CPU and total blocks exceed GPU capacity.
|
||||
"""
|
||||
cpu_blocks = 0
|
||||
for logical_id in seq.block_table:
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.location == BlockLocation.CPU:
|
||||
cpu_blocks += 1
|
||||
return cpu_blocks > 0 and len(seq.block_table) > self.num_gpu_slots
|
||||
|
||||
def load_all_kv_for_layer(
|
||||
self,
|
||||
seq: Sequence,
|
||||
layer_id: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Load ALL KV for a sequence from both GPU and CPU for a layer.
|
||||
|
||||
Used during chunked decode to compute full attention.
|
||||
|
||||
Returns:
|
||||
(k, v) tensors with shape [1, total_tokens, kv_heads, head_dim]
|
||||
"""
|
||||
k_chunks = []
|
||||
v_chunks = []
|
||||
|
||||
for logical_id in seq.block_table:
|
||||
block = self.logical_blocks[logical_id]
|
||||
|
||||
if block.location == BlockLocation.GPU:
|
||||
# Get from GPU cache
|
||||
k, v = self.offload_engine.get_layer_cache(layer_id)
|
||||
# k, v shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
|
||||
k_block = k[block.gpu_slot] # [block_size, kv_heads, head_dim]
|
||||
v_block = v[block.gpu_slot]
|
||||
k_chunks.append(k_block)
|
||||
v_chunks.append(v_block)
|
||||
|
||||
elif block.location == BlockLocation.CPU:
|
||||
# Get from CPU cache
|
||||
k_block, v_block = self.offload_engine.get_cpu_block(layer_id, block.cpu_block_id)
|
||||
# Already [block_size, kv_heads, head_dim]
|
||||
k_chunks.append(k_block.to("cuda", non_blocking=True))
|
||||
v_chunks.append(v_block.to("cuda", non_blocking=True))
|
||||
|
||||
# Concatenate all chunks
|
||||
k_all = torch.cat(k_chunks, dim=0) # [total_tokens, kv_heads, head_dim]
|
||||
v_all = torch.cat(v_chunks, dim=0)
|
||||
|
||||
# Add batch dimension
|
||||
k_all = k_all.unsqueeze(0) # [1, total_tokens, kv_heads, head_dim]
|
||||
v_all = v_all.unsqueeze(0)
|
||||
|
||||
return k_all, v_all
|
||||
|
||||
def get_gpu_block_tables(
|
||||
self,
|
||||
seqs: List[Sequence],
|
||||
) -> List[List[int]]:
|
||||
"""
|
||||
Get GPU slot tables for sequences.
|
||||
|
||||
Returns GPU slot IDs, which may differ from logical block IDs.
|
||||
"""
|
||||
result = []
|
||||
for seq in seqs:
|
||||
gpu_table = []
|
||||
for logical_id in seq.block_table:
|
||||
block = self.logical_blocks[logical_id]
|
||||
assert block.location == BlockLocation.GPU, (
|
||||
f"Block {logical_id} not on GPU (location={block.location})"
|
||||
)
|
||||
gpu_table.append(block.gpu_slot)
|
||||
result.append(gpu_table)
|
||||
return result
|
||||
|
||||
def post_attention_cleanup(
|
||||
self,
|
||||
seqs: List[Sequence],
|
||||
is_prefill: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Cleanup after attention.
|
||||
|
||||
        Clears pending loads. A proactive offload pass could be hooked in here later.
|
||||
"""
|
||||
self.pending_gpu_loads.clear()
|
||||
|
||||
# ========== Chunked Prefill Support ==========
|
||||
|
||||
def needs_chunked_prefill(self, seq: Sequence) -> bool:
|
||||
"""
|
||||
Check if sequence needs chunked prefill.
|
||||
|
||||
Returns True if there are unprefilled blocks that are on CPU.
|
||||
This indicates we need to process in chunks because not all blocks fit on GPU.
|
||||
"""
|
||||
for logical_id in seq.block_table:
|
||||
if logical_id not in self.prefilled_blocks:
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.location == BlockLocation.CPU:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_gpu_block_count(self, seq: Sequence) -> int:
|
||||
"""Get number of blocks currently on GPU for this sequence."""
|
||||
count = 0
|
||||
for logical_id in seq.block_table:
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.location == BlockLocation.GPU:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
def get_prefill_chunk_info(self, seq: Sequence) -> Tuple[int, int, List[int]]:
|
||||
"""
|
||||
Get information for current prefill chunk.
|
||||
|
||||
Returns:
|
||||
(start_block_idx, end_block_idx, gpu_block_ids)
|
||||
- start_block_idx: First block index in this chunk
|
||||
- end_block_idx: Last block index (exclusive) in this chunk
|
||||
- gpu_block_ids: GPU slot IDs for blocks in this chunk
|
||||
"""
|
||||
start_idx = -1
|
||||
end_idx = -1
|
||||
gpu_block_ids = []
|
||||
|
||||
for i, logical_id in enumerate(seq.block_table):
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.location == BlockLocation.GPU:
|
||||
if start_idx == -1:
|
||||
start_idx = i
|
||||
end_idx = i + 1
|
||||
gpu_block_ids.append(block.gpu_slot)
|
||||
elif start_idx != -1:
|
||||
# Found CPU block after GPU blocks - stop here
|
||||
break
|
||||
|
||||
if start_idx == -1:
|
||||
return (0, 0, [])
|
||||
|
||||
return (start_idx, end_idx, gpu_block_ids)
|
||||
|
||||
def complete_prefill_chunk(self, seq: Sequence) -> bool:
|
||||
"""
|
||||
Complete a prefill chunk: mark blocks as prefilled, offload to CPU, load next chunk.
|
||||
|
||||
Returns:
|
||||
True if there are more chunks to process, False if done.
|
||||
"""
|
||||
# Find blocks currently on GPU that were just prefilled
|
||||
gpu_blocks_to_offload = []
|
||||
for logical_id in seq.block_table:
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.location == BlockLocation.GPU and logical_id not in self.prefilled_blocks:
|
||||
# Mark as prefilled
|
||||
self.prefilled_blocks.add(logical_id)
|
||||
gpu_blocks_to_offload.append(logical_id)
|
||||
|
||||
# Offload prefilled GPU blocks to CPU
|
||||
for logical_id in gpu_blocks_to_offload:
|
||||
block = self.logical_blocks[logical_id]
|
||||
if not self.free_cpu_blocks:
|
||||
raise RuntimeError("No free CPU blocks for offload")
|
||||
|
||||
cpu_block_id = self.free_cpu_blocks.popleft()
|
||||
|
||||
# Async offload all layers
|
||||
for layer_id in range(self.offload_engine.num_layers):
|
||||
self.offload_engine.offload_block_async(
|
||||
layer_id=layer_id,
|
||||
gpu_block_id=block.gpu_slot,
|
||||
cpu_block_id=cpu_block_id,
|
||||
)
|
||||
|
||||
# Update mappings
|
||||
self.free_gpu_slots.append(block.gpu_slot)
|
||||
del self.gpu_slot_to_logical[block.gpu_slot]
|
||||
self.cpu_block_to_logical[cpu_block_id] = logical_id
|
||||
|
||||
block.location = BlockLocation.CPU
|
||||
block.cpu_block_id = cpu_block_id
|
||||
block.gpu_slot = -1
|
||||
|
||||
# Wait for offload to complete
|
||||
self.offload_engine.wait_all_transfers()
|
||||
|
||||
# Find next UNPREFILLED CPU blocks and bring them to GPU
|
||||
cpu_blocks_to_load = []
|
||||
for logical_id in seq.block_table:
|
||||
if logical_id in self.prefilled_blocks:
|
||||
continue # Skip already prefilled
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.location == BlockLocation.CPU:
|
||||
if len(cpu_blocks_to_load) >= self.num_gpu_slots:
|
||||
break # GPU is full
|
||||
cpu_blocks_to_load.append(logical_id)
|
||||
|
||||
if not cpu_blocks_to_load:
|
||||
return False # All blocks have been prefilled
|
||||
|
||||
# Load unprefilled CPU blocks to GPU
|
||||
for logical_id in cpu_blocks_to_load:
|
||||
block = self.logical_blocks[logical_id]
|
||||
gpu_slot = self.free_gpu_slots.popleft()
|
||||
|
||||
# Note: We're NOT prefetching existing data - these blocks are being
|
||||
# loaded for the first time, so we just need to assign GPU slots
|
||||
# The model will write new KV cache data to these slots
|
||||
|
||||
# Update mappings
|
||||
self.free_cpu_blocks.append(block.cpu_block_id)
|
||||
del self.cpu_block_to_logical[block.cpu_block_id]
|
||||
self.gpu_slot_to_logical[gpu_slot] = logical_id
|
||||
|
||||
block.location = BlockLocation.GPU
|
||||
block.gpu_slot = gpu_slot
|
||||
block.cpu_block_id = -1
|
||||
|
||||
return True # More chunks to process
|
||||
|
||||
def get_gpu_block_tables_partial(
|
||||
self,
|
||||
seqs: List[Sequence],
|
||||
) -> List[Tuple[List[int], int, int]]:
|
||||
"""
|
||||
Get GPU block tables for chunked prefill.
|
||||
|
||||
Returns list of (gpu_block_ids, start_block_idx, end_block_idx) per sequence.
|
||||
Only includes blocks that are currently on GPU AND haven't been prefilled yet.
|
||||
"""
|
||||
result = []
|
||||
for seq in seqs:
|
||||
gpu_table = []
|
||||
start_idx = -1
|
||||
end_idx = -1
|
||||
|
||||
for i, logical_id in enumerate(seq.block_table):
|
||||
# Skip already prefilled blocks
|
||||
if logical_id in self.prefilled_blocks:
|
||||
continue
|
||||
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.location == BlockLocation.GPU:
|
||||
if start_idx == -1:
|
||||
start_idx = i
|
||||
end_idx = i + 1
|
||||
gpu_table.append(block.gpu_slot)
|
||||
elif start_idx != -1:
|
||||
# Stop at first non-GPU block after GPU blocks
|
||||
break
|
||||
|
||||
if start_idx == -1:
|
||||
start_idx = 0
|
||||
end_idx = 0
|
||||
|
||||
result.append((gpu_table, start_idx, end_idx))
|
||||
return result
|
||||
|
||||
def get_prefilled_cpu_blocks(self, seq: Sequence) -> List[int]:
|
||||
"""
|
||||
Get list of CPU block IDs for blocks that have been prefilled.
|
||||
|
||||
Used for loading previous KV during chunked prefill.
|
||||
|
||||
Returns:
|
||||
List of CPU block IDs in sequence order
|
||||
"""
|
||||
cpu_blocks = []
|
||||
for logical_id in seq.block_table:
|
||||
if logical_id in self.prefilled_blocks:
|
||||
block = self.logical_blocks[logical_id]
|
||||
if block.location == BlockLocation.CPU:
|
||||
cpu_blocks.append(block.cpu_block_id)
|
||||
return cpu_blocks
|
||||
|
||||
def load_prev_kv_for_layer(
|
||||
self,
|
||||
seq: Sequence,
|
||||
layer_id: int,
|
||||
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Load previous prefilled KV from CPU for a specific layer.
|
||||
|
||||
This concatenates KV from all previously prefilled blocks for use
|
||||
during chunked prefill attention.
|
||||
|
||||
Args:
|
||||
seq: Sequence to load KV for
|
||||
layer_id: Layer index
|
||||
|
||||
Returns:
|
||||
(k, v) tensors with shape [1, total_prev_tokens, kv_heads, head_dim]
|
||||
or (None, None) if no previous KV exists
|
||||
"""
|
||||
cpu_blocks = self.get_prefilled_cpu_blocks(seq)
|
||||
if not cpu_blocks:
|
||||
return None, None
|
||||
|
||||
k_chunks = []
|
||||
v_chunks = []
|
||||
|
||||
for cpu_block_id in cpu_blocks:
|
||||
k, v = self.offload_engine.get_cpu_block(layer_id, cpu_block_id)
|
||||
# k, v shape: [block_size, kv_heads, head_dim]
|
||||
k_chunks.append(k)
|
||||
v_chunks.append(v)
|
||||
|
||||
# Concatenate all chunks
|
||||
k_prev = torch.cat(k_chunks, dim=0) # [total_prev_tokens, kv_heads, head_dim]
|
||||
v_prev = torch.cat(v_chunks, dim=0)
|
||||
|
||||
# Move to GPU and add batch dimension
|
||||
k_prev = k_prev.to("cuda", non_blocking=True).unsqueeze(0) # [1, tokens, heads, dim]
|
||||
v_prev = v_prev.to("cuda", non_blocking=True).unsqueeze(0)
|
||||
|
||||
return k_prev, v_prev
|
||||
|
||||
def get_chunk_start_position(self, seq: Sequence) -> int:
|
||||
"""
|
||||
Get the starting token position for the current chunk.
|
||||
|
||||
This is the total number of tokens in previously prefilled blocks.
|
||||
|
||||
Returns:
|
||||
Token position offset for current chunk
|
||||
"""
|
||||
pos = 0
|
||||
for logical_id in seq.block_table:
|
||||
if logical_id in self.prefilled_blocks:
|
||||
# Full block's worth of tokens
|
||||
pos += self._block_size
|
||||
else:
|
||||
break
|
||||
return pos
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"HybridKVCacheManager(\n"
|
||||
f" num_gpu_slots={self.num_gpu_slots},\n"
|
||||
f" num_cpu_blocks={self.num_cpu_blocks},\n"
|
||||
f" block_size={self._block_size},\n"
|
||||
f" free_logical={len(self.free_logical_ids)},\n"
|
||||
f" free_gpu={len(self.free_gpu_slots)},\n"
|
||||
f" free_cpu={len(self.free_cpu_blocks)},\n"
|
||||
f" policy={self.policy}\n"
|
||||
f")"
|
||||
)
|
||||
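A hypothetical driver loop for chunked prefill against the manager API above. run_prefill_chunk stands in for the model-runner step that writes KV and merges attention for the blocks currently resident on GPU; it is not part of this commit.

def chunked_prefill(manager: HybridKVCacheManager, seq: Sequence, run_prefill_chunk):
    while True:
        start_idx, end_idx, gpu_slots = manager.get_prefill_chunk_info(seq)
        if not gpu_slots:
            break
        chunk_start_tok = manager.get_chunk_start_position(seq)
        run_prefill_chunk(seq, gpu_slots, chunk_start_tok)   # writes KV for this chunk
        if not manager.complete_prefill_chunk(seq):          # offloads chunk, stages the next one
            break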
190
nanovllm/kvcache/kernels.py
Normal file
190
nanovllm/kvcache/kernels.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""
|
||||
Triton kernels for CPU-GPU KV cache transfer.
|
||||
|
||||
These kernels are designed to be CUDA Graph compatible:
|
||||
- All tensor addresses are fixed at graph capture time
|
||||
- Only the content of index tensors changes between replays
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def gathered_copy_kernel(
|
||||
src_ptr, # Source tensor base pointer (CPU pinned or GPU)
|
||||
dst_ptr, # Destination tensor base pointer (GPU)
|
||||
indices_ptr, # Gather indices [num_dst_blocks]
|
||||
num_dst_blocks, # Number of destination blocks
|
||||
block_numel: tl.constexpr, # Elements per block (block_size * kv_heads * head_dim)
|
||||
BLOCK_SIZE: tl.constexpr = 1024,
|
||||
):
|
||||
"""
|
||||
Gathered copy kernel: dst[i] = src[indices[i]]
|
||||
|
||||
Each program instance handles one destination block.
|
||||
The indices tensor specifies which source block to copy from.
|
||||
|
||||
This kernel is CUDA Graph compatible because:
|
||||
- src_ptr, dst_ptr, indices_ptr addresses are fixed
|
||||
- Only indices content changes between graph replays
|
||||
|
||||
Args:
|
||||
src_ptr: Base pointer to source blocks [num_src_blocks, block_numel]
|
||||
dst_ptr: Base pointer to destination blocks [num_dst_blocks, block_numel]
|
||||
indices_ptr: Gather indices [num_dst_blocks], each value is a source block index
|
||||
num_dst_blocks: Number of destination blocks to copy
|
||||
block_numel: Number of elements per block
|
||||
BLOCK_SIZE: Triton block size for parallelization
|
||||
"""
|
||||
dst_block_idx = tl.program_id(0)
|
||||
|
||||
# Skip if out of range
|
||||
if dst_block_idx >= num_dst_blocks:
|
||||
return
|
||||
|
||||
# Load source block index from indices tensor
|
||||
src_block_idx = tl.load(indices_ptr + dst_block_idx)
|
||||
|
||||
# Skip if index is -1 (invalid/no-op marker)
|
||||
if src_block_idx < 0:
|
||||
return
|
||||
|
||||
# Calculate base offsets
|
||||
src_base = src_block_idx * block_numel
|
||||
dst_base = dst_block_idx * block_numel
|
||||
|
||||
# Copy block data in chunks of BLOCK_SIZE
|
||||
for start in range(0, block_numel, BLOCK_SIZE):
|
||||
offsets = start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < block_numel
|
||||
|
||||
# Load from source and store to destination
|
||||
data = tl.load(src_ptr + src_base + offsets, mask=mask)
|
||||
tl.store(dst_ptr + dst_base + offsets, data, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def gathered_copy_kv_kernel(
|
||||
k_src_ptr, # K cache source [num_src_blocks, block_size, kv_heads, head_dim]
|
||||
v_src_ptr, # V cache source
|
||||
k_dst_ptr, # K cache destination
|
||||
v_dst_ptr, # V cache destination
|
||||
indices_ptr, # Gather indices [num_dst_blocks]
|
||||
num_dst_blocks, # Number of destination blocks
|
||||
block_numel: tl.constexpr, # Elements per block
|
||||
BLOCK_SIZE: tl.constexpr = 1024,
|
||||
):
|
||||
"""
|
||||
Gathered copy for both K and V caches simultaneously.
|
||||
|
||||
More efficient than calling gathered_copy_kernel twice because:
|
||||
- Single kernel launch overhead
|
||||
- Better memory access patterns when K and V are accessed together
|
||||
"""
|
||||
dst_block_idx = tl.program_id(0)
|
||||
|
||||
if dst_block_idx >= num_dst_blocks:
|
||||
return
|
||||
|
||||
src_block_idx = tl.load(indices_ptr + dst_block_idx)
|
||||
|
||||
if src_block_idx < 0:
|
||||
return
|
||||
|
||||
src_base = src_block_idx * block_numel
|
||||
dst_base = dst_block_idx * block_numel
|
||||
|
||||
for start in range(0, block_numel, BLOCK_SIZE):
|
||||
offsets = start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < block_numel
|
||||
|
||||
# Copy K cache
|
||||
k_data = tl.load(k_src_ptr + src_base + offsets, mask=mask)
|
||||
tl.store(k_dst_ptr + dst_base + offsets, k_data, mask=mask)
|
||||
|
||||
# Copy V cache
|
||||
v_data = tl.load(v_src_ptr + src_base + offsets, mask=mask)
|
||||
tl.store(v_dst_ptr + dst_base + offsets, v_data, mask=mask)
|
||||
|
||||
|
||||
def gathered_copy(
|
||||
src: torch.Tensor,
|
||||
dst: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Perform gathered copy: dst[i] = src[indices[i]]
|
||||
|
||||
Args:
|
||||
src: Source tensor [num_src_blocks, ...]
|
||||
dst: Destination tensor [num_dst_blocks, ...]
|
||||
indices: Index tensor [num_dst_blocks], dtype=int64
|
||||
-1 means skip (no-op)
|
||||
|
||||
Note:
|
||||
- src can be on CPU (pinned memory) or GPU
|
||||
- dst must be on GPU
|
||||
- indices must be on GPU
|
||||
- All shapes after first dimension must match
|
||||
"""
|
||||
assert dst.is_cuda, "Destination must be on GPU"
|
||||
assert indices.is_cuda, "Indices must be on GPU"
|
||||
assert src.shape[1:] == dst.shape[1:], "Shape mismatch after first dimension"
|
||||
|
||||
num_dst_blocks = dst.shape[0]
|
||||
block_numel = dst[0].numel()
|
||||
|
||||
# Flatten for kernel
|
||||
src_flat = src.view(src.shape[0], -1)
|
||||
dst_flat = dst.view(dst.shape[0], -1)
|
||||
|
||||
grid = (num_dst_blocks,)
|
||||
gathered_copy_kernel[grid](
|
||||
src_flat,
|
||||
dst_flat,
|
||||
indices,
|
||||
num_dst_blocks,
|
||||
block_numel=block_numel,
|
||||
)
|
||||
|
||||
|
||||
def gathered_copy_kv(
|
||||
k_src: torch.Tensor,
|
||||
v_src: torch.Tensor,
|
||||
k_dst: torch.Tensor,
|
||||
v_dst: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
) -> None:
|
||||
"""
|
||||
Perform gathered copy for both K and V caches.
|
||||
|
||||
Args:
|
||||
k_src, v_src: Source K/V caches [num_src_blocks, block_size, kv_heads, head_dim]
|
||||
k_dst, v_dst: Destination K/V caches [num_dst_blocks, block_size, kv_heads, head_dim]
|
||||
indices: Index tensor [num_dst_blocks], dtype=int64
|
||||
"""
|
||||
assert k_dst.is_cuda and v_dst.is_cuda, "Destinations must be on GPU"
|
||||
assert indices.is_cuda, "Indices must be on GPU"
|
||||
assert k_src.shape[1:] == k_dst.shape[1:], "K shape mismatch"
|
||||
assert v_src.shape[1:] == v_dst.shape[1:], "V shape mismatch"
|
||||
|
||||
num_dst_blocks = k_dst.shape[0]
|
||||
block_numel = k_dst[0].numel()
|
||||
|
||||
k_src_flat = k_src.view(k_src.shape[0], -1)
|
||||
v_src_flat = v_src.view(v_src.shape[0], -1)
|
||||
k_dst_flat = k_dst.view(k_dst.shape[0], -1)
|
||||
v_dst_flat = v_dst.view(v_dst.shape[0], -1)
|
||||
|
||||
grid = (num_dst_blocks,)
|
||||
gathered_copy_kv_kernel[grid](
|
||||
k_src_flat,
|
||||
v_src_flat,
|
||||
k_dst_flat,
|
||||
v_dst_flat,
|
||||
indices,
|
||||
num_dst_blocks,
|
||||
block_numel=block_numel,
|
||||
)
|
||||
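A small usage sketch for gathered_copy_kv with illustrative sizes. Per the note in gathered_copy above, the source tensors may live in pinned host memory; -1 entries in the index tensor leave the corresponding GPU slots untouched. All sizes and slot/block numbers below are placeholders.

import torch
from nanovllm.kvcache.kernels import gathered_copy_kv

num_cpu, num_gpu, block, heads, dim = 64, 8, 256, 8, 128
k_cpu = torch.randn(num_cpu, block, heads, dim, dtype=torch.float16).pin_memory()
v_cpu = torch.randn(num_cpu, block, heads, dim, dtype=torch.float16).pin_memory()
k_gpu = torch.empty(num_gpu, block, heads, dim, dtype=torch.float16, device="cuda")
v_gpu = torch.empty(num_gpu, block, heads, dim, dtype=torch.float16, device="cuda")

indices = torch.full((num_gpu,), -1, dtype=torch.int64, device="cuda")
indices[0] = 17   # GPU slot 0 <- CPU block 17
indices[3] = 42   # GPU slot 3 <- CPU block 42

gathered_copy_kv(k_cpu, v_cpu, k_gpu, v_gpu, indices)
torch.cuda.synchronize()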
400
nanovllm/kvcache/offload_engine.py
Normal file
400
nanovllm/kvcache/offload_engine.py
Normal file
@@ -0,0 +1,400 @@
|
||||
"""
|
||||
High-performance CPU-GPU KV cache transfer engine.
|
||||
|
||||
Key design principles for CUDA Graph compatibility:
|
||||
1. All tensor addresses are fixed at initialization
|
||||
2. Only index tensor contents change between graph replays
|
||||
3. Supports both async transfer (for prefill) and graph-based transfer (for decode)
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from typing import Dict, List, Tuple, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
from nanovllm.kvcache.kernels import gathered_copy_kv
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransferEvent:
|
||||
"""Tracks a pending async transfer."""
|
||||
event: torch.cuda.Event
|
||||
layer_id: int
|
||||
src_block_id: int
|
||||
dst_block_id: int
|
||||
direction: str # "h2d" or "d2h"
|
||||
|
||||
|
||||
class OffloadEngine:
|
||||
"""
|
||||
High-performance CPU-GPU async transfer engine for KV cache offloading.
|
||||
|
||||
Memory layout:
|
||||
- GPU cache: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
|
||||
- CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] (pinned)
|
||||
- Gather indices: [num_layers, num_gpu_blocks] (fixed address, variable content)
|
||||
|
||||
CUDA Graph compatibility:
|
||||
- gathered_h2d_layer() can be captured into CUDA graphs
|
||||
- update_gather_indices() is called outside graphs to prepare indices
|
||||
- All tensor addresses remain fixed across graph replays
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_layers: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_dim: int,
|
||||
dtype: torch.dtype = torch.float16,
|
||||
num_streams: int = 4,
|
||||
):
|
||||
self.num_layers = num_layers
|
||||
self.num_gpu_blocks = num_gpu_blocks
|
||||
self.num_cpu_blocks = num_cpu_blocks
|
||||
self.block_size = block_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.head_dim = head_dim
|
||||
self.dtype = dtype
|
||||
self.kv_dim = num_kv_heads * head_dim
|
||||
self.block_numel = block_size * self.kv_dim
|
||||
|
||||
# ========== Fixed-address GPU KV cache ==========
|
||||
# Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
|
||||
self.k_cache_gpu = torch.empty(
|
||||
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cuda"
|
||||
)
|
||||
self.v_cache_gpu = torch.empty(
|
||||
num_layers, num_gpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cuda"
|
||||
)
|
||||
|
||||
# ========== Fixed-address CPU KV cache (pinned memory) ==========
|
||||
self.k_cache_cpu = torch.empty(
|
||||
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cpu", pin_memory=True
|
||||
)
|
||||
self.v_cache_cpu = torch.empty(
|
||||
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=dtype, device="cpu", pin_memory=True
|
||||
)
|
||||
|
||||
# ========== Fixed-address gather indices (content is variable) ==========
|
||||
# gather_indices[layer][i] = CPU block id to copy to GPU slot i
|
||||
# -1 means no-op (skip this slot)
|
||||
self.gather_indices_cpu = torch.empty(
|
||||
num_layers, num_gpu_blocks,
|
||||
dtype=torch.int64, device="cpu", pin_memory=True
|
||||
)
|
||||
self.gather_indices_cpu.fill_(-1)
|
||||
self.gather_indices_gpu = torch.full(
|
||||
(num_layers, num_gpu_blocks), -1,
|
||||
dtype=torch.int64, device="cuda"
|
||||
)
|
||||
|
||||
# ========== Transfer streams for async operations ==========
|
||||
self.transfer_streams = [torch.cuda.Stream() for _ in range(num_streams)]
|
||||
self.compute_stream = torch.cuda.current_stream()
|
||||
self._stream_idx = 0
|
||||
|
||||
# ========== Event tracking for async transfers ==========
|
||||
self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
|
||||
|
||||
def _get_next_stream(self) -> torch.cuda.Stream:
|
||||
"""Round-robin stream selection for parallel transfers."""
|
||||
stream = self.transfer_streams[self._stream_idx]
|
||||
self._stream_idx = (self._stream_idx + 1) % len(self.transfer_streams)
|
||||
return stream
|
||||
|
||||
# ========== CUDA Graph compatible methods ==========
|
||||
|
||||
def gathered_h2d_layer(self, layer_id: int) -> None:
|
||||
"""
|
||||
Execute gathered H2D copy for a single layer.
|
||||
|
||||
This method is CUDA Graph compatible - can be captured into a graph.
|
||||
Before calling, update_gather_indices() must be called to set up
|
||||
which CPU blocks to copy to which GPU slots.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index to transfer
|
||||
"""
|
||||
gathered_copy_kv(
|
||||
k_src=self.k_cache_cpu[layer_id],
|
||||
v_src=self.v_cache_cpu[layer_id],
|
||||
k_dst=self.k_cache_gpu[layer_id],
|
||||
v_dst=self.v_cache_gpu[layer_id],
|
||||
indices=self.gather_indices_gpu[layer_id],
|
||||
)
|
||||
|
||||
def gathered_h2d_all_layers(self) -> None:
|
||||
"""
|
||||
Execute gathered H2D copy for all layers.
|
||||
|
||||
CUDA Graph compatible - can be captured into a single graph.
|
||||
"""
|
||||
for layer_id in range(self.num_layers):
|
||||
self.gathered_h2d_layer(layer_id)
|
||||
|
||||
def update_gather_indices(
|
||||
self,
|
||||
layer_id: int,
|
||||
mappings: List[Tuple[int, int]],
|
||||
) -> None:
|
||||
"""
|
||||
Update gather indices for a layer (call OUTSIDE CUDA graph).
|
||||
|
||||
Args:
|
||||
layer_id: Layer index
|
||||
mappings: List of (cpu_block_id, gpu_slot) tuples
|
||||
Only these slots will be updated; others keep their values
|
||||
"""
|
||||
for cpu_block_id, gpu_slot in mappings:
|
||||
self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
|
||||
|
||||
# Async copy to GPU
|
||||
self.gather_indices_gpu[layer_id].copy_(
|
||||
self.gather_indices_cpu[layer_id],
|
||||
non_blocking=True
|
||||
)
|
||||
|
||||
def update_gather_indices_all_layers(
|
||||
self,
|
||||
mappings_per_layer: List[List[Tuple[int, int]]],
|
||||
) -> None:
|
||||
"""
|
||||
Update gather indices for all layers.
|
||||
|
||||
Args:
|
||||
mappings_per_layer: mappings_per_layer[layer_id] = [(cpu_block_id, gpu_slot), ...]
|
||||
"""
|
||||
for layer_id, mappings in enumerate(mappings_per_layer):
|
||||
for cpu_block_id, gpu_slot in mappings:
|
||||
self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
|
||||
|
||||
# Batch copy all layers
|
||||
self.gather_indices_gpu.copy_(self.gather_indices_cpu, non_blocking=True)
|
||||
|
||||
def clear_gather_indices(self, layer_id: Optional[int] = None) -> None:
|
||||
"""
|
||||
Clear gather indices (set all to -1, meaning no-op).
|
||||
|
||||
Args:
|
||||
layer_id: If provided, clear only this layer; otherwise clear all
|
||||
"""
|
||||
if layer_id is not None:
|
||||
self.gather_indices_cpu[layer_id].fill_(-1)
|
||||
self.gather_indices_gpu[layer_id].fill_(-1)
|
||||
else:
|
||||
self.gather_indices_cpu.fill_(-1)
|
||||
self.gather_indices_gpu.fill_(-1)
|
||||
|
||||
# ========== Async transfer methods (for prefill, outside CUDA graph) ==========
|
||||
|
||||
def prefetch_block_async(
|
||||
self,
|
||||
layer_id: int,
|
||||
cpu_block_id: int,
|
||||
gpu_block_id: int,
|
||||
) -> torch.cuda.Event:
|
||||
"""
|
||||
Async prefetch a single block from CPU to GPU.
|
||||
|
||||
For use in prefill phase where CUDA graphs are not used.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index
|
||||
cpu_block_id: Source block in CPU cache
|
||||
gpu_block_id: Destination slot in GPU cache
|
||||
|
||||
Returns:
|
||||
CUDA event that signals completion
|
||||
"""
|
||||
stream = self._get_next_stream()
|
||||
event = torch.cuda.Event()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
# K cache
|
||||
self.k_cache_gpu[layer_id, gpu_block_id].copy_(
|
||||
self.k_cache_cpu[layer_id, cpu_block_id],
|
||||
non_blocking=True
|
||||
)
|
||||
# V cache
|
||||
self.v_cache_gpu[layer_id, gpu_block_id].copy_(
|
||||
self.v_cache_cpu[layer_id, cpu_block_id],
|
||||
non_blocking=True
|
||||
)
|
||||
event.record()
|
||||
|
||||
self.pending_events[(layer_id, gpu_block_id)] = event
|
||||
return event
|
||||
|
||||
def prefetch_blocks_batch_async(
|
||||
self,
|
||||
transfers: List[Tuple[int, int, int]], # [(layer_id, cpu_block_id, gpu_block_id), ...]
|
||||
) -> List[torch.cuda.Event]:
|
||||
"""
|
||||
Batch async prefetch multiple blocks.
|
||||
|
||||
Args:
|
||||
transfers: List of (layer_id, cpu_block_id, gpu_block_id) tuples
|
||||
|
||||
Returns:
|
||||
List of CUDA events for each transfer
|
||||
"""
|
||||
events = []
|
||||
for layer_id, cpu_block_id, gpu_block_id in transfers:
|
||||
event = self.prefetch_block_async(layer_id, cpu_block_id, gpu_block_id)
|
||||
events.append(event)
|
||||
return events
|
||||
|
||||
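    # Illustrative prefill-path usage (comment only): batch-prefetch every layer of
    # one logical block into a GPU slot, overlapping copies across the transfer
    # streams, then wait before attention reads the slot:
    #
    #   events = engine.prefetch_blocks_batch_async(
    #       [(layer, cpu_blk, gpu_slot) for layer in range(engine.num_layers)])
    #   engine.wait_all_transfers()   # or wait_for_block(layer, gpu_slot) per block
    #
    # `engine`, `cpu_blk`, and `gpu_slot` are placeholder names for this sketch.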
def offload_block_async(
|
||||
self,
|
||||
layer_id: int,
|
||||
gpu_block_id: int,
|
||||
cpu_block_id: int,
|
||||
) -> torch.cuda.Event:
|
||||
"""
|
||||
Async offload a block from GPU to CPU.
|
||||
|
||||
Args:
|
||||
layer_id: Layer index
|
||||
gpu_block_id: Source slot in GPU cache
|
||||
cpu_block_id: Destination block in CPU cache
|
||||
|
||||
Returns:
|
||||
CUDA event that signals completion
|
||||
"""
|
||||
stream = self._get_next_stream()
|
||||
event = torch.cuda.Event()
|
||||
|
||||
with torch.cuda.stream(stream):
|
||||
# Wait for any compute using this block
|
||||
stream.wait_stream(self.compute_stream)
|
||||
|
||||
# K cache
|
||||
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
|
||||
self.k_cache_gpu[layer_id, gpu_block_id],
|
||||
non_blocking=True
|
||||
)
|
||||
# V cache
|
||||
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
|
||||
self.v_cache_gpu[layer_id, gpu_block_id],
|
||||
non_blocking=True
|
||||
)
|
||||
event.record()
|
||||
|
||||
return event
|
||||
|
||||
def offload_blocks_batch_async(
|
||||
self,
|
||||
transfers: List[Tuple[int, int, int]], # [(layer_id, gpu_block_id, cpu_block_id), ...]
|
||||
) -> List[torch.cuda.Event]:
|
||||
"""
|
||||
Batch async offload multiple blocks.
|
||||
|
||||
Args:
|
||||
transfers: List of (layer_id, gpu_block_id, cpu_block_id) tuples
|
||||
|
||||
Returns:
|
||||
List of CUDA events
|
||||
"""
|
||||
events = []
|
||||
for layer_id, gpu_block_id, cpu_block_id in transfers:
|
||||
event = self.offload_block_async(layer_id, gpu_block_id, cpu_block_id)
|
||||
events.append(event)
|
||||
return events
|
||||
|
||||
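A hedged sketch of how the async path is typically driven: issue the copies on the transfer streams, keep the compute stream busy, and synchronize as late as possible via the returned events or the helpers in the synchronization section below. It reuses the small `engine` from the sketch above:

```python
# (layer_id, cpu_block_id, gpu_block_id) triples: bring CPU block 0 into GPU slot 2.
transfers = [(layer_id, 0, 2) for layer_id in range(engine.num_layers)]
events = engine.prefetch_blocks_batch_async(transfers)

# ... enqueue compute that does not read GPU slot 2 yet ...

engine.wait_for_block(layer_id=0, gpu_block_id=2)  # block on one transfer
engine.wait_all_transfers()                        # or drain every transfer stream

# Write-back is symmetric: (layer_id, gpu_block_id, cpu_block_id).
engine.offload_blocks_batch_async([(layer_id, 2, 5) for layer_id in range(engine.num_layers)])
engine.wait_all_transfers()
```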
    # ========== Synchronization methods ==========

    def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None:
        """Wait for a specific block's transfer to complete."""
        key = (layer_id, gpu_block_id)
        if key in self.pending_events:
            self.pending_events[key].synchronize()
            del self.pending_events[key]

    def wait_all_transfers(self) -> None:
        """Wait for all pending transfers to complete."""
        for stream in self.transfer_streams:
            stream.synchronize()
        self.pending_events.clear()

    def sync_indices(self) -> None:
        """Synchronize to ensure all index updates are complete."""
        torch.cuda.current_stream().synchronize()

    # ========== Cache access methods ==========

    def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
        """
        Get GPU K/V cache tensors for a specific layer.

        Returns:
            (k_cache, v_cache) tensors for the layer
            Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
        """
        return self.k_cache_gpu[layer_id], self.v_cache_gpu[layer_id]

    def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]:
        """
        Get full GPU K/V cache tensors.

        Returns:
            (k_cache, v_cache) tensors
            Shape: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
        """
        return self.k_cache_gpu, self.v_cache_gpu

    def get_cpu_block(
        self,
        layer_id: int,
        cpu_block_id: int,
    ) -> Tuple[Tensor, Tensor]:
        """
        Get a specific CPU block's K/V cache.

        Returns:
            (k_cache, v_cache) for the block
            Shape: [block_size, kv_heads, head_dim]
        """
        return (
            self.k_cache_cpu[layer_id, cpu_block_id],
            self.v_cache_cpu[layer_id, cpu_block_id],
        )

    # ========== Memory info ==========

    def gpu_memory_bytes(self) -> int:
        """Total GPU memory used by KV caches."""
        return (
            self.k_cache_gpu.numel() * self.k_cache_gpu.element_size() +
            self.v_cache_gpu.numel() * self.v_cache_gpu.element_size() +
            self.gather_indices_gpu.numel() * self.gather_indices_gpu.element_size()
        )

    def cpu_memory_bytes(self) -> int:
        """Total CPU memory used by KV caches."""
        return (
            self.k_cache_cpu.numel() * self.k_cache_cpu.element_size() +
            self.v_cache_cpu.numel() * self.v_cache_cpu.element_size() +
            self.gather_indices_cpu.numel() * self.gather_indices_cpu.element_size()
        )

    def __repr__(self) -> str:
        return (
            f"OffloadEngine(\n"
            f"  num_layers={self.num_layers},\n"
            f"  num_gpu_blocks={self.num_gpu_blocks},\n"
            f"  num_cpu_blocks={self.num_cpu_blocks},\n"
            f"  block_size={self.block_size},\n"
            f"  kv_heads={self.num_kv_heads},\n"
            f"  head_dim={self.head_dim},\n"
            f"  dtype={self.dtype},\n"
            f"  gpu_memory={self.gpu_memory_bytes() / 1024**2:.1f}MB,\n"
            f"  cpu_memory={self.cpu_memory_bytes() / 1024**2:.1f}MB\n"
            f")"
        )
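Both byte counters reduce to 2 (K and V) × num_layers × num_blocks × block_size × kv_heads × head_dim × element_size, plus the small gather-index tensors. A back-of-the-envelope check with hypothetical numbers (fp16, 32 layers, block_size 256, 8 KV heads, head_dim 128 — not the project defaults):

```python
elem = 2                                      # bytes per fp16 element
k_per_block_layer = 256 * 8 * 128 * elem      # 512 KiB of K per block per layer
kv_per_block_layer = 2 * k_per_block_layer    # 1 MiB with V included

gpu_gib = 32 * 64 * kv_per_block_layer / 1024**3    # 64 GPU blocks  -> 2.0 GiB
cpu_gib = 32 * 512 * kv_per_block_layer / 1024**3   # 512 CPU blocks -> 16.0 GiB
print(f"GPU ~{gpu_gib:.1f} GiB, CPU ~{cpu_gib:.1f} GiB (index tensors add a few KiB)")
```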
51
nanovllm/kvcache/policies/__init__.py
Normal file
@@ -0,0 +1,51 @@
"""
Eviction policy plugins for KV cache offloading.

Users can create custom policies by subclassing EvictionPolicy
and specifying the full class path in config.offload_policy.
"""

from nanovllm.kvcache.policies.base_policy import EvictionPolicy
from nanovllm.kvcache.policies.lru_policy import LRUPolicy
from nanovllm.kvcache.policies.fifo_policy import FIFOPolicy

# Built-in policy registry
BUILTIN_POLICIES = {
    "lru": LRUPolicy,
    "fifo": FIFOPolicy,
}


def get_policy(policy_name: str) -> EvictionPolicy:
    """
    Get an eviction policy instance by name or class path.

    Args:
        policy_name: Either a built-in name ("lru", "fifo") or
            a full class path ("mymodule.MyPolicy")

    Returns:
        EvictionPolicy instance
    """
    # Check built-in policies first
    if policy_name.lower() in BUILTIN_POLICIES:
        return BUILTIN_POLICIES[policy_name.lower()]()

    # Try to import custom policy
    try:
        module_path, class_name = policy_name.rsplit(".", 1)
        import importlib
        module = importlib.import_module(module_path)
        policy_class = getattr(module, class_name)
        if not issubclass(policy_class, EvictionPolicy):
            raise TypeError(f"{policy_name} is not a subclass of EvictionPolicy")
        return policy_class()
    except (ValueError, ImportError, AttributeError) as e:
        raise ValueError(
            f"Unknown policy '{policy_name}'. "
            f"Available built-in policies: {list(BUILTIN_POLICIES.keys())}. "
            f"For custom policies, use full class path: 'mymodule.MyPolicy'"
        ) from e


__all__ = ["EvictionPolicy", "LRUPolicy", "FIFOPolicy", "get_policy", "BUILTIN_POLICIES"]
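Usage, mirroring what tests/test_policies.py exercises later in this commit; the commented line shows the class-path form with a hypothetical user module:

```python
from nanovllm.kvcache.policies import get_policy
from nanovllm.kvcache.policies.lru_policy import LRUPolicy

policy = get_policy("lru")            # built-in name, case-insensitive
assert isinstance(policy, LRUPolicy)
policy = get_policy("nanovllm.kvcache.policies.lru_policy.LRUPolicy")  # full class path
# policy = get_policy("my_pkg.policies.MyPolicy")   # hypothetical custom policy
```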
156
nanovllm/kvcache/policies/base_policy.py
Normal file
@@ -0,0 +1,156 @@
"""
Base class for eviction policies.

Users can implement custom policies by subclassing EvictionPolicy
and overriding the abstract methods.
"""

from abc import ABC, abstractmethod
from typing import Set, Optional


class EvictionPolicy(ABC):
    """
    Abstract base class for KV cache eviction policies.

    An eviction policy determines which GPU blocks to evict to CPU
    when GPU memory is full and new blocks need to be allocated.

    Lifecycle:
        1. on_block_allocated() - called when a new block is allocated
        2. on_block_access() - called each time a block is accessed (e.g., in attention)
        3. select_victim() - called when a block needs to be evicted
        4. on_block_evicted() - called after a block is evicted

    Example custom policy:
    ```python
    class MyCustomPolicy(EvictionPolicy):
        def __init__(self):
            self.priorities = {}

        def on_block_allocated(self, block_id: int, step: int):
            self.priorities[block_id] = step

        def on_block_access(self, block_id: int, step: int):
            # Custom access tracking
            pass

        def select_victim(self, candidates: Set[int]) -> int:
            # Return block with lowest priority
            return min(candidates, key=lambda b: self.priorities.get(b, 0))

        def on_block_evicted(self, block_id: int):
            self.priorities.pop(block_id, None)
    ```
    """

    @abstractmethod
    def on_block_allocated(self, block_id: int, step: int) -> None:
        """
        Called when a new block is allocated on GPU.

        Args:
            block_id: The GPU block ID that was allocated
            step: Current inference step (monotonically increasing)
        """
        pass

    @abstractmethod
    def on_block_access(self, block_id: int, step: int) -> None:
        """
        Called when a block is accessed during attention computation.

        Args:
            block_id: The GPU block ID being accessed
            step: Current inference step
        """
        pass

    @abstractmethod
    def select_victim(self, candidates: Set[int]) -> int:
        """
        Select a block to evict from the candidate set.

        This is called when GPU memory is full and a new block
        needs to be allocated. The returned block will be evicted
        to CPU.

        Args:
            candidates: Set of GPU block IDs that can be evicted
                (blocks not currently being used)

        Returns:
            Block ID to evict

        Raises:
            ValueError: If candidates is empty
        """
        pass

    @abstractmethod
    def on_block_evicted(self, block_id: int) -> None:
        """
        Called after a block is evicted from GPU to CPU.

        Args:
            block_id: The GPU block ID that was evicted
        """
        pass

    def on_block_prefetched(self, block_id: int, step: int) -> None:
        """
        Called when a block is prefetched from CPU back to GPU.

        Default implementation calls on_block_allocated().
        Override for custom behavior.

        Args:
            block_id: The GPU block ID that was prefetched to
            step: Current inference step
        """
        self.on_block_allocated(block_id, step)

    def on_block_deallocated(self, block_id: int) -> None:
        """
        Called when a block is fully deallocated (sequence finished).

        Default implementation calls on_block_evicted().
        Override for custom behavior.

        Args:
            block_id: The GPU block ID being deallocated
        """
        self.on_block_evicted(block_id)

    def reset(self) -> None:
        """
        Reset policy state.

        Called when the inference engine is reset.
        Default implementation does nothing.
        """
        pass

    def get_eviction_order(self, candidates: Set[int], count: int) -> list:
        """
        Get multiple blocks to evict in order of priority.

        Default implementation calls select_victim() repeatedly.
        Override for more efficient batch selection.

        Args:
            candidates: Set of candidate block IDs
            count: Number of blocks to evict

        Returns:
            List of block IDs to evict, in order
        """
        result = []
        remaining = set(candidates)
        for _ in range(min(count, len(remaining))):
            if not remaining:
                break
            victim = self.select_victim(remaining)
            result.append(victim)
            remaining.remove(victim)
        return result
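A short walkthrough of the lifecycle described in the class docstring, using the built-in LRUPolicy defined below; the final call shows the batch helper (LRUPolicy overrides it with a single pass, but the result matches the base-class default):

```python
from nanovllm.kvcache.policies.lru_policy import LRUPolicy

policy = LRUPolicy()
policy.on_block_allocated(0, step=1)
policy.on_block_allocated(1, step=2)
policy.on_block_access(0, step=3)          # block 0 is now most recently used

victim = policy.select_victim({0, 1})      # -> 1, the least recently used
policy.on_block_evicted(victim)

policy.on_block_allocated(2, step=4)
print(policy.get_eviction_order({0, 2}, count=2))   # [0, 2]
```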
101
nanovllm/kvcache/policies/fifo_policy.py
Normal file
@@ -0,0 +1,101 @@
"""
FIFO (First In, First Out) eviction policy.

Evicts the block that was allocated earliest.
Simple policy that ignores access patterns.
"""

from collections import OrderedDict
from typing import Set

from nanovllm.kvcache.policies.base_policy import EvictionPolicy


class FIFOPolicy(EvictionPolicy):
    """
    First In, First Out (FIFO) eviction policy.

    Evicts blocks in the order they were allocated,
    regardless of access patterns.

    Properties:
        - O(1) operations for all methods
        - Simple and predictable behavior
        - Good for streaming workloads where older data
          is naturally less relevant
        - Does not adapt to access patterns (unlike LRU)
    """

    def __init__(self):
        # OrderedDict maintains insertion order
        # Key: block_id, Value: allocation_step
        # Oldest (first allocated) is at the front
        self.allocation_order: OrderedDict[int, int] = OrderedDict()

    def on_block_allocated(self, block_id: int, step: int) -> None:
        """Record allocation order (does not change on access)."""
        if block_id not in self.allocation_order:
            self.allocation_order[block_id] = step

    def on_block_access(self, block_id: int, step: int) -> None:
        """
        FIFO ignores access patterns.

        This is the key difference from LRU - we don't
        update the position based on access.
        """
        pass  # Intentionally empty

    def select_victim(self, candidates: Set[int]) -> int:
        """
        Select the earliest allocated block from candidates.
        """
        if not candidates:
            raise ValueError("Cannot select victim from empty candidate set")

        # Iterate from oldest (front) to newest (back)
        for block_id in self.allocation_order:
            if block_id in candidates:
                return block_id

        # Fallback: return any candidate
        return next(iter(candidates))

    def on_block_evicted(self, block_id: int) -> None:
        """Remove block from tracking."""
        self.allocation_order.pop(block_id, None)

    def on_block_prefetched(self, block_id: int, step: int) -> None:
        """
        When prefetched, treat as new allocation.

        This moves the block to the end of the queue,
        giving it more time before eviction.
        """
        # Remove old entry if exists
        self.allocation_order.pop(block_id, None)
        # Add as new allocation
        self.allocation_order[block_id] = step

    def on_block_deallocated(self, block_id: int) -> None:
        """Remove block from tracking."""
        self.allocation_order.pop(block_id, None)

    def reset(self) -> None:
        """Clear all tracking data."""
        self.allocation_order.clear()

    def get_eviction_order(self, candidates: Set[int], count: int) -> list:
        """
        Get multiple blocks to evict in FIFO order.
        """
        result = []
        for block_id in self.allocation_order:
            if block_id in candidates:
                result.append(block_id)
                if len(result) >= count:
                    break
        return result

    def __repr__(self) -> str:
        return f"FIFOPolicy(tracked_blocks={len(self.allocation_order)})"
93
nanovllm/kvcache/policies/lru_policy.py
Normal file
@@ -0,0 +1,93 @@
"""
LRU (Least Recently Used) eviction policy.

Evicts the block that was accessed least recently.
This is the default and recommended policy for most use cases.
"""

from collections import OrderedDict
from typing import Set

from nanovllm.kvcache.policies.base_policy import EvictionPolicy


class LRUPolicy(EvictionPolicy):
    """
    Least Recently Used (LRU) eviction policy.

    Maintains an ordered dictionary of block access times.
    When eviction is needed, selects the block that was
    accessed least recently.

    Properties:
        - O(1) access tracking
        - O(n) victim selection in worst case, but typically fast
          due to OrderedDict iteration order
        - Good for workloads with temporal locality
    """

    def __init__(self):
        # OrderedDict maintains insertion/update order
        # Key: block_id, Value: last_access_step
        # Oldest (least recently used) is at the front
        self.access_order: OrderedDict[int, int] = OrderedDict()

    def on_block_allocated(self, block_id: int, step: int) -> None:
        """Record allocation as an access."""
        # Move to end (most recently used)
        self.access_order[block_id] = step
        self.access_order.move_to_end(block_id)

    def on_block_access(self, block_id: int, step: int) -> None:
        """Update access time and move to end."""
        if block_id in self.access_order:
            self.access_order[block_id] = step
            self.access_order.move_to_end(block_id)

    def select_victim(self, candidates: Set[int]) -> int:
        """
        Select the least recently used block from candidates.

        Iterates from oldest to newest in access order,
        returns the first one that's in the candidate set.
        """
        if not candidates:
            raise ValueError("Cannot select victim from empty candidate set")

        # Iterate from oldest (front) to newest (back)
        for block_id in self.access_order:
            if block_id in candidates:
                return block_id

        # Fallback: return any candidate (shouldn't happen normally)
        return next(iter(candidates))

    def on_block_evicted(self, block_id: int) -> None:
        """Remove block from tracking."""
        self.access_order.pop(block_id, None)

    def on_block_deallocated(self, block_id: int) -> None:
        """Remove block from tracking."""
        self.access_order.pop(block_id, None)

    def reset(self) -> None:
        """Clear all tracking data."""
        self.access_order.clear()

    def get_eviction_order(self, candidates: Set[int], count: int) -> list:
        """
        Efficiently get multiple blocks to evict in LRU order.

        Optimized for batch eviction - iterates through access_order
        once instead of calling select_victim() multiple times.
        """
        result = []
        for block_id in self.access_order:
            if block_id in candidates:
                result.append(block_id)
                if len(result) >= count:
                    break
        return result

    def __repr__(self) -> str:
        return f"LRUPolicy(tracked_blocks={len(self.access_order)})"
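The behavioural difference between the two built-ins in one sketch: under the same allocation and access pattern, FIFO still evicts the oldest allocation, while LRU protects the block that was just touched (these are exactly the cases asserted in tests/test_policies.py):

```python
from nanovllm.kvcache.policies import FIFOPolicy, LRUPolicy

for policy in (FIFOPolicy(), LRUPolicy()):
    for block_id in (0, 1, 2):
        policy.on_block_allocated(block_id, step=block_id)
    policy.on_block_access(0, step=10)          # re-touch the oldest block
    print(policy, "->", policy.select_victim({0, 1, 2}))
# FIFOPolicy(...) -> 0   (allocation order wins)
# LRUPolicy(...)  -> 1   (block 0 became most recently used)
```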
@@ -55,21 +55,164 @@ class Attention(nn.Module):
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])
        # Layer ID set by model_runner after model creation
        self.layer_id: int = -1

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)

        if context.is_prefill:
            if context.is_chunked_prefill:
                # Chunked prefill: merge attention from previous KV
                o = self._chunked_prefill_attention(q, k, v, context)
            elif context.block_tables is not None:    # prefix cache
                k, v = k_cache, v_cache
                o = flash_attn_varlen_func(q, k, v,
                                           max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                           max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                           softmax_scale=self.scale, causal=True, block_table=context.block_tables)
            else:
                o = flash_attn_varlen_func(q, k, v,
                                           max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                           max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                           softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:    # decode
            if context.is_chunked_prefill:
                # Chunked decode: need to load all KV from CPU+GPU
                o = self._chunked_decode_attention(q, k, v, context)
            else:
                o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                            cache_seqlens=context.context_lens, block_table=context.block_tables,
                                            softmax_scale=self.scale, causal=True)
        return o

    def _chunked_prefill_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute attention with chunked KV from CPU cache.

        For chunked prefill:
        1. Load previous KV from CPU for this layer
        2. Compute attention against previous KV (no causal mask)
        3. Compute attention against current chunk's KV (causal)
        4. Merge results using online softmax
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        # q, k, v shape: [total_tokens, num_heads, head_dim]
        total_tokens = q.shape[0]

        # Reshape for flash attention: [batch, seq, heads, dim]
        q_batched = q.unsqueeze(0)  # [1, total_tokens, heads, dim]
        k_batched = k.unsqueeze(0)
        v_batched = v.unsqueeze(0)

        accumulated_o = None
        accumulated_lse = None

        # Load previous KV from CPU for this layer
        if context.offload_engine is not None and self.layer_id >= 0:
            # Get the kvcache_manager from context
            kvcache_manager = context.offload_engine

            # For each sequence in the chunk, load previous KV
            # Currently assuming single sequence
            if hasattr(context, 'chunked_seq') and context.chunked_seq is not None:
                prev_k, prev_v = kvcache_manager.load_prev_kv_for_layer(
                    context.chunked_seq,
                    self.layer_id,
                )

                if prev_k is not None and prev_v is not None:
                    # Compute attention against previous KV (no causal mask)
                    prev_o, prev_lse = flash_attn_with_lse(
                        q_batched,
                        prev_k,
                        prev_v,
                        softmax_scale=self.scale,
                        causal=False,  # No causal mask for previous context
                    )
                    accumulated_o = prev_o
                    accumulated_lse = prev_lse

        # Compute attention against current chunk's KV (with causal mask)
        current_o, current_lse = flash_attn_with_lse(
            q_batched,
            k_batched,
            v_batched,
            softmax_scale=self.scale,
            causal=True,  # Causal mask for current chunk
        )

        # Merge with accumulated
        if accumulated_o is None:
            final_o = current_o
        else:
            final_o, _ = merge_attention_outputs(
                accumulated_o, accumulated_lse,
                current_o, current_lse,
            )

        # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
        return final_o.squeeze(0)

    def _chunked_decode_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute decode attention with KV spread across CPU and GPU.

        For decode with chunked KV:
        1. Load all KV for this layer from CPU+GPU
        2. Compute attention (1 query token vs all KV)
        3. Return output
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse

        # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
        # We need to attend to ALL previous tokens

        # Load all KV for this layer
        if context.offload_engine is not None and self.layer_id >= 0:
            kvcache_manager = context.offload_engine

            if hasattr(context, 'chunked_seq') and context.chunked_seq is not None:
                # Load all KV from both GPU and CPU for this layer
                k_all, v_all = kvcache_manager.load_all_kv_for_layer(
                    context.chunked_seq,
                    self.layer_id,
                )

                if k_all is not None and v_all is not None:
                    # q shape: [batch_size, num_heads, head_dim]
                    # Need: [batch, seqlen, heads, dim]
                    # Insert seqlen dimension at position 1
                    q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]

                    # k_all, v_all shape: [1, total_kv_tokens, kv_heads, head_dim]
                    # Compute attention (no causal mask for decode - we want all KV)
                    out, _ = flash_attn_with_lse(
                        q_batched,
                        k_all,
                        v_all,
                        softmax_scale=self.scale,
                        causal=False,  # No causal mask for decode
                    )

                    # Output shape: [batch, 1, heads, dim] -> [batch, heads, dim]
                    return out.squeeze(1)

        # Fallback: shouldn't reach here
        raise RuntimeError("Chunked decode attention failed: no KV available")
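Both chunked paths lean on flash_attn_with_lse and merge_attention_outputs from nanovllm/kvcache/chunked_attention.py, which is not shown in this hunk. The merge is the standard log-sum-exp combination of two partial attention results computed over disjoint KV sets; a minimal sketch of what such a helper is expected to compute, assuming the LSE comes back as [batch, num_heads, seqlen_q] (the usual flash-attn layout):

```python
import torch

def merge_attention_outputs_sketch(o1, lse1, o2, lse2):
    # o1, o2:     [batch, seqlen_q, num_heads, head_dim] partial outputs
    # lse1, lse2: [batch, num_heads, seqlen_q] log-sum-exp of the score rows
    lse = torch.logaddexp(lse1, lse2)                         # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # [batch, seqlen_q, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse
```

Because the previous context and the current chunk cover disjoint key ranges, attending to them separately with causal=False and causal=True and then merging this way is mathematically identical to a single full attention over the concatenated KV.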
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from typing import Optional, List, Tuple, Any
import torch


@@ -13,14 +14,60 @@ class Context:
    context_lens: torch.Tensor | None = None
    block_tables: torch.Tensor | None = None

    # Chunked prefill support
    is_chunked_prefill: bool = False
    # Previous KV chunks info: List of (start_pos, end_pos) for blocks on CPU
    prev_kv_ranges: List[Tuple[int, int]] = field(default_factory=list)
    # Current chunk's position offset (for causal mask)
    chunk_offset: int = 0
    # Reference to kvcache manager for loading previous KV (HybridKVCacheManager)
    offload_engine: Any = None
    # Current layer's previous K/V chunks (loaded from CPU)
    # Set by model_runner before each layer's forward
    prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
    # Current sequence being processed (for chunked prefill to load KV)
    chunked_seq: Any = None


_CONTEXT = Context()


def get_context():
    return _CONTEXT


def set_context(
    is_prefill,
    cu_seqlens_q=None,
    cu_seqlens_k=None,
    max_seqlen_q=0,
    max_seqlen_k=0,
    slot_mapping=None,
    context_lens=None,
    block_tables=None,
    is_chunked_prefill=False,
    prev_kv_ranges=None,
    chunk_offset=0,
    offload_engine=None,
    chunked_seq=None,
):
    global _CONTEXT
    _CONTEXT = Context(
        is_prefill=is_prefill,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
        is_chunked_prefill=is_chunked_prefill,
        prev_kv_ranges=prev_kv_ranges or [],
        chunk_offset=chunk_offset,
        offload_engine=offload_engine,
        chunked_seq=chunked_seq,
    )


def reset_context():
    global _CONTEXT
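A sketch of how a chunked-prefill step might fill in the new fields; the import path is assumed to be nanovllm/utils/context.py and the values are illustrative rather than taken from model_runner:

```python
import torch
from nanovllm.utils.context import set_context, get_context, reset_context

set_context(
    is_prefill=True,
    max_seqlen_q=1024,
    max_seqlen_k=1024,
    slot_mapping=torch.arange(1024, dtype=torch.int32),
    is_chunked_prefill=True,
    chunk_offset=4096,        # tokens already prefilled by earlier chunks
    offload_engine=None,      # normally the HybridKVCacheManager
    chunked_seq=None,         # normally the Sequence being prefilled
)
assert get_context().is_chunked_prefill
reset_context()
```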
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
"""Test suite for nano-vllm KV cache offload."""
169
tests/test_kernels.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Tests for Triton gathered copy kernels."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from nanovllm.kvcache.kernels import gathered_copy, gathered_copy_kv
|
||||
|
||||
|
||||
class TestGatheredCopy:
|
||||
"""Tests for gathered copy kernel."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_tensors(self):
|
||||
"""Create test tensors."""
|
||||
torch.cuda.manual_seed(42)
|
||||
num_src_blocks = 16
|
||||
num_dst_blocks = 8
|
||||
block_size = 256
|
||||
kv_dim = 64
|
||||
|
||||
src = torch.randn(num_src_blocks, block_size, kv_dim,
|
||||
dtype=torch.float16, device="cuda")
|
||||
dst = torch.zeros(num_dst_blocks, block_size, kv_dim,
|
||||
dtype=torch.float16, device="cuda")
|
||||
|
||||
# Indices: dst[i] = src[indices[i]]
|
||||
indices = torch.randint(0, num_src_blocks, (num_dst_blocks,),
|
||||
dtype=torch.int64, device="cuda")
|
||||
|
||||
return src, dst, indices
|
||||
|
||||
def test_basic_copy(self, setup_tensors):
|
||||
"""Test basic gathered copy."""
|
||||
src, dst, indices = setup_tensors
|
||||
|
||||
gathered_copy(src, dst, indices)
|
||||
|
||||
# Verify copy
|
||||
for i in range(len(indices)):
|
||||
src_idx = indices[i].item()
|
||||
assert torch.allclose(dst[i], src[src_idx]), f"Mismatch at index {i}"
|
||||
|
||||
def test_skip_negative_indices(self, setup_tensors):
|
||||
"""Test that negative indices are skipped."""
|
||||
src, dst, indices = setup_tensors
|
||||
|
||||
# Set some indices to -1
|
||||
indices[2] = -1
|
||||
indices[5] = -1
|
||||
|
||||
# Fill dst with a known value
|
||||
dst.fill_(999.0)
|
||||
|
||||
gathered_copy(src, dst, indices)
|
||||
|
||||
# Skipped slots should be unchanged
|
||||
assert (dst[2] == 999.0).all()
|
||||
assert (dst[5] == 999.0).all()
|
||||
|
||||
# Non-skipped slots should be copied
|
||||
for i in [0, 1, 3, 4, 6, 7]:
|
||||
src_idx = indices[i].item()
|
||||
assert torch.allclose(dst[i], src[src_idx])
|
||||
|
||||
def test_single_block(self):
|
||||
"""Test copying a single block."""
|
||||
src = torch.randn(4, 256, 64, dtype=torch.float16, device="cuda")
|
||||
dst = torch.zeros(1, 256, 64, dtype=torch.float16, device="cuda")
|
||||
indices = torch.tensor([2], dtype=torch.int64, device="cuda")
|
||||
|
||||
gathered_copy(src, dst, indices)
|
||||
|
||||
assert torch.allclose(dst[0], src[2])
|
||||
|
||||
|
||||
class TestGatheredCopyKV:
|
||||
"""Tests for gathered K/V cache copy kernel."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_kv_tensors(self):
|
||||
"""Create K/V test tensors."""
|
||||
torch.cuda.manual_seed(42)
|
||||
num_src_blocks = 16
|
||||
num_dst_blocks = 8
|
||||
block_size = 256
|
||||
num_kv_heads = 4
|
||||
head_dim = 64
|
||||
|
||||
k_src = torch.randn(num_src_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device="cuda")
|
||||
v_src = torch.randn(num_src_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device="cuda")
|
||||
k_dst = torch.zeros(num_dst_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device="cuda")
|
||||
v_dst = torch.zeros(num_dst_blocks, block_size, num_kv_heads, head_dim,
|
||||
dtype=torch.float16, device="cuda")
|
||||
|
||||
indices = torch.randint(0, num_src_blocks, (num_dst_blocks,),
|
||||
dtype=torch.int64, device="cuda")
|
||||
|
||||
return k_src, v_src, k_dst, v_dst, indices
|
||||
|
||||
def test_kv_copy(self, setup_kv_tensors):
|
||||
"""Test K/V gathered copy."""
|
||||
k_src, v_src, k_dst, v_dst, indices = setup_kv_tensors
|
||||
|
||||
gathered_copy_kv(k_src, v_src, k_dst, v_dst, indices)
|
||||
|
||||
# Verify copy
|
||||
for i in range(len(indices)):
|
||||
src_idx = indices[i].item()
|
||||
assert torch.allclose(k_dst[i], k_src[src_idx]), f"K mismatch at {i}"
|
||||
assert torch.allclose(v_dst[i], v_src[src_idx]), f"V mismatch at {i}"
|
||||
|
||||
def test_kv_skip_negative(self, setup_kv_tensors):
|
||||
"""Test that negative indices are skipped for K/V."""
|
||||
k_src, v_src, k_dst, v_dst, indices = setup_kv_tensors
|
||||
|
||||
indices[0] = -1
|
||||
k_dst.fill_(999.0)
|
||||
v_dst.fill_(999.0)
|
||||
|
||||
gathered_copy_kv(k_src, v_src, k_dst, v_dst, indices)
|
||||
|
||||
assert (k_dst[0] == 999.0).all()
|
||||
assert (v_dst[0] == 999.0).all()
|
||||
|
||||
|
||||
class TestPerformance:
|
||||
"""Performance benchmarks for gathered copy."""
|
||||
|
||||
@pytest.mark.parametrize("num_blocks", [8, 32, 128])
|
||||
def test_throughput(self, num_blocks):
|
||||
"""Benchmark copy throughput."""
|
||||
block_size = 256
|
||||
kv_dim = 64
|
||||
|
||||
src = torch.randn(num_blocks * 2, block_size, kv_dim,
|
||||
dtype=torch.float16, device="cuda")
|
||||
dst = torch.zeros(num_blocks, block_size, kv_dim,
|
||||
dtype=torch.float16, device="cuda")
|
||||
indices = torch.arange(num_blocks, dtype=torch.int64, device="cuda")
|
||||
|
||||
# Warmup
|
||||
for _ in range(10):
|
||||
gathered_copy(src, dst, indices)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Benchmark
|
||||
import time
|
||||
start = time.perf_counter()
|
||||
num_iters = 100
|
||||
for _ in range(num_iters):
|
||||
gathered_copy(src, dst, indices)
|
||||
torch.cuda.synchronize()
|
||||
elapsed = time.perf_counter() - start
|
||||
|
||||
bytes_copied = num_blocks * block_size * kv_dim * 2 * num_iters # fp16
|
||||
bandwidth_gbps = bytes_copied / elapsed / 1e9
|
||||
|
||||
print(f"\n{num_blocks} blocks: {bandwidth_gbps:.2f} GB/s")
|
||||
|
||||
# Should achieve reasonable bandwidth (lower threshold for small blocks due to kernel launch overhead)
|
||||
min_bandwidth = 5 if num_blocks <= 16 else 10
|
||||
assert bandwidth_gbps > min_bandwidth, f"Bandwidth too low: {bandwidth_gbps} GB/s"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
175
tests/test_kvcache_manager.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""Tests for KV cache managers."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from nanovllm.engine.sequence import Sequence
|
||||
from nanovllm.kvcache.gpu_manager import GPUOnlyManager
|
||||
|
||||
|
||||
class MockSequence:
|
||||
"""Mock sequence for testing block allocation."""
|
||||
|
||||
def __init__(self, token_ids: list[int], block_size: int = 256):
|
||||
self._token_ids = token_ids
|
||||
self._block_size = block_size
|
||||
self.block_table: list[int] = []
|
||||
self.num_cached_tokens = 0
|
||||
|
||||
def __len__(self):
|
||||
return len(self._token_ids)
|
||||
|
||||
@property
|
||||
def num_blocks(self) -> int:
|
||||
return (len(self) + self._block_size - 1) // self._block_size
|
||||
|
||||
def block(self, i: int) -> list[int]:
|
||||
start = i * self._block_size
|
||||
end = min((i + 1) * self._block_size, len(self))
|
||||
return self._token_ids[start:end]
|
||||
|
||||
|
||||
class TestGPUOnlyManager:
|
||||
"""Tests for GPU-only KV cache manager."""
|
||||
|
||||
@pytest.fixture
|
||||
def manager(self):
|
||||
"""Create a small manager for testing."""
|
||||
return GPUOnlyManager(num_blocks=16, block_size=256)
|
||||
|
||||
def test_initialization(self, manager):
|
||||
"""Test manager initialization."""
|
||||
assert manager.block_size == 256
|
||||
assert manager.num_free_blocks == 16
|
||||
assert len(manager.blocks) == 16
|
||||
|
||||
def test_allocate_cache(self, manager):
|
||||
"""Test cache allocation."""
|
||||
manager.allocate_cache(
|
||||
num_layers=4,
|
||||
num_kv_heads=8,
|
||||
head_dim=64,
|
||||
dtype=torch.float16,
|
||||
)
|
||||
|
||||
assert manager.kv_cache is not None
|
||||
assert manager.kv_cache.shape == (2, 4, 16, 256, 8, 64)
|
||||
assert manager.kv_cache.device.type == "cuda"
|
||||
|
||||
def test_get_layer_cache(self, manager):
|
||||
"""Test getting layer cache."""
|
||||
manager.allocate_cache(
|
||||
num_layers=4,
|
||||
num_kv_heads=8,
|
||||
head_dim=64,
|
||||
dtype=torch.float16,
|
||||
)
|
||||
|
||||
k_cache, v_cache = manager.get_layer_cache(0)
|
||||
assert k_cache.shape == (16, 256, 8, 64)
|
||||
assert v_cache.shape == (16, 256, 8, 64)
|
||||
|
||||
def test_can_allocate(self, manager):
|
||||
"""Test allocation check."""
|
||||
seq = MockSequence([0] * 300) # Needs 2 blocks
|
||||
assert manager.can_allocate(seq)
|
||||
|
||||
# Fill up all blocks with unique tokens to avoid prefix caching
|
||||
for i in range(8):
|
||||
# Each sequence has unique tokens to prevent prefix cache hits
|
||||
s = MockSequence([i * 1000 + j for j in range(300)])
|
||||
manager.allocate(s)
|
||||
|
||||
# Now should not be able to allocate
|
||||
new_seq = MockSequence([9999] * 300)
|
||||
assert not manager.can_allocate(new_seq)
|
||||
|
||||
def test_allocate_and_deallocate(self, manager):
|
||||
"""Test block allocation and deallocation."""
|
||||
seq = MockSequence([0] * 600) # Needs 3 blocks
|
||||
initial_free = manager.num_free_blocks
|
||||
|
||||
manager.allocate(seq)
|
||||
assert len(seq.block_table) == 3
|
||||
assert manager.num_free_blocks == initial_free - 3
|
||||
|
||||
manager.deallocate(seq)
|
||||
assert len(seq.block_table) == 0
|
||||
assert manager.num_free_blocks == initial_free
|
||||
|
||||
def test_can_append(self, manager):
|
||||
"""Test append check."""
|
||||
seq = MockSequence([0] * 256) # Exactly 1 block
|
||||
manager.allocate(seq)
|
||||
|
||||
# Can append without new block (still in same block)
|
||||
seq._token_ids = [0] * 257
|
||||
assert manager.can_append(seq)
|
||||
|
||||
def test_prepare_for_attention_noop(self, manager):
|
||||
"""Test that prepare_for_attention is a no-op for GPU-only."""
|
||||
seq = MockSequence([0] * 100)
|
||||
manager.allocate(seq)
|
||||
|
||||
# Should not raise
|
||||
manager.prepare_for_attention([seq], is_prefill=True)
|
||||
manager.prepare_for_attention([seq], is_prefill=False)
|
||||
|
||||
def test_get_gpu_block_tables(self, manager):
|
||||
"""Test getting GPU block tables."""
|
||||
seq1 = MockSequence([0] * 300)
|
||||
seq2 = MockSequence([0] * 600)
|
||||
|
||||
manager.allocate(seq1)
|
||||
manager.allocate(seq2)
|
||||
|
||||
tables = manager.get_gpu_block_tables([seq1, seq2])
|
||||
|
||||
assert len(tables) == 2
|
||||
assert tables[0] == list(seq1.block_table)
|
||||
assert tables[1] == list(seq2.block_table)
|
||||
|
||||
|
||||
class TestGPUOnlyManagerPrefixCaching:
|
||||
"""Tests for prefix caching in GPU-only manager."""
|
||||
|
||||
@pytest.fixture
|
||||
def manager(self):
|
||||
"""Create manager for testing."""
|
||||
return GPUOnlyManager(num_blocks=32, block_size=256)
|
||||
|
||||
def test_prefix_cache_hit(self, manager):
|
||||
"""Test that identical prefixes are cached."""
|
||||
# Create two sequences with same prefix
|
||||
tokens = list(range(512)) # 2 full blocks
|
||||
seq1 = MockSequence(tokens)
|
||||
seq2 = MockSequence(tokens)
|
||||
|
||||
manager.allocate(seq1)
|
||||
initial_free = manager.num_free_blocks
|
||||
|
||||
manager.allocate(seq2)
|
||||
|
||||
# Second sequence should reuse cached blocks
|
||||
assert seq2.num_cached_tokens >= 256 # At least first block cached
|
||||
# Should use fewer new blocks
|
||||
assert manager.num_free_blocks >= initial_free - 2
|
||||
|
||||
def test_prefix_cache_different_suffix(self, manager):
|
||||
"""Test cache with same prefix but different suffix."""
|
||||
prefix = list(range(256)) # 1 full block
|
||||
|
||||
seq1 = MockSequence(prefix + [1000, 1001])
|
||||
seq2 = MockSequence(prefix + [2000, 2001])
|
||||
|
||||
manager.allocate(seq1)
|
||||
manager.allocate(seq2)
|
||||
|
||||
# First block should be shared
|
||||
assert seq1.block_table[0] == seq2.block_table[0]
|
||||
# Second block should be different
|
||||
assert seq1.block_table[1] != seq2.block_table[1]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
196
tests/test_offload_engine.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""Tests for CPU-GPU offload engine."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from nanovllm.kvcache.offload_engine import OffloadEngine
|
||||
|
||||
|
||||
class TestOffloadEngine:
|
||||
"""Tests for OffloadEngine."""
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self):
|
||||
"""Create a small engine for testing."""
|
||||
return OffloadEngine(
|
||||
num_layers=2,
|
||||
num_gpu_blocks=4,
|
||||
num_cpu_blocks=8,
|
||||
block_size=256,
|
||||
num_kv_heads=4,
|
||||
head_dim=64,
|
||||
dtype=torch.float16,
|
||||
num_streams=2,
|
||||
)
|
||||
|
||||
def test_initialization(self, engine):
|
||||
"""Test engine initialization."""
|
||||
# Check GPU cache shape
|
||||
assert engine.k_cache_gpu.shape == (2, 4, 256, 4, 64)
|
||||
assert engine.v_cache_gpu.shape == (2, 4, 256, 4, 64)
|
||||
|
||||
# Check CPU cache shape
|
||||
assert engine.k_cache_cpu.shape == (2, 8, 256, 4, 64)
|
||||
assert engine.v_cache_cpu.shape == (2, 8, 256, 4, 64)
|
||||
|
||||
# Check pinned memory
|
||||
assert engine.k_cache_cpu.is_pinned()
|
||||
assert engine.v_cache_cpu.is_pinned()
|
||||
|
||||
# Check gather indices
|
||||
assert engine.gather_indices_cpu.shape == (2, 4)
|
||||
assert engine.gather_indices_gpu.shape == (2, 4)
|
||||
|
||||
def test_get_layer_cache(self, engine):
|
||||
"""Test getting layer cache."""
|
||||
k, v = engine.get_layer_cache(0)
|
||||
assert k.shape == (4, 256, 4, 64)
|
||||
assert v.shape == (4, 256, 4, 64)
|
||||
assert k.device.type == "cuda"
|
||||
assert v.device.type == "cuda"
|
||||
|
||||
def test_prefetch_and_offload(self, engine):
|
||||
"""Test async prefetch and offload."""
|
||||
# Write some data to CPU block 0
|
||||
engine.k_cache_cpu[0, 0].fill_(1.0)
|
||||
engine.v_cache_cpu[0, 0].fill_(2.0)
|
||||
|
||||
# Prefetch to GPU block 2
|
||||
event = engine.prefetch_block_async(
|
||||
layer_id=0,
|
||||
cpu_block_id=0,
|
||||
gpu_block_id=2,
|
||||
)
|
||||
event.synchronize()
|
||||
|
||||
# Verify data was copied (move GPU to CPU for comparison)
|
||||
assert torch.allclose(engine.k_cache_gpu[0, 2].cpu(), engine.k_cache_cpu[0, 0])
|
||||
assert torch.allclose(engine.v_cache_gpu[0, 2].cpu(), engine.v_cache_cpu[0, 0])
|
||||
|
||||
# Modify GPU data
|
||||
engine.k_cache_gpu[0, 2].fill_(3.0)
|
||||
engine.v_cache_gpu[0, 2].fill_(4.0)
|
||||
|
||||
# Offload to CPU block 5
|
||||
event = engine.offload_block_async(
|
||||
layer_id=0,
|
||||
gpu_block_id=2,
|
||||
cpu_block_id=5,
|
||||
)
|
||||
event.synchronize()
|
||||
|
||||
# Verify data was copied
|
||||
assert torch.allclose(engine.k_cache_cpu[0, 5], engine.k_cache_gpu[0, 2].cpu())
|
||||
assert torch.allclose(engine.v_cache_cpu[0, 5], engine.v_cache_gpu[0, 2].cpu())
|
||||
|
||||
def test_update_gather_indices(self, engine):
|
||||
"""Test updating gather indices."""
|
||||
# Manually set CPU data
|
||||
for i in range(8):
|
||||
engine.k_cache_cpu[0, i].fill_(float(i))
|
||||
engine.v_cache_cpu[0, i].fill_(float(i + 100))
|
||||
|
||||
# Update indices for layer 0: (cpu_block_id, gpu_slot)
|
||||
mappings = [(2, 0), (5, 1), (1, 2), (7, 3)]
|
||||
engine.update_gather_indices(layer_id=0, mappings=mappings)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Verify indices were set
|
||||
expected = torch.tensor([2, 5, 1, 7], dtype=torch.int64)
|
||||
assert torch.equal(engine.gather_indices_cpu[0], expected)
|
||||
|
||||
def test_gathered_h2d_layer(self, engine):
|
||||
"""Test gathered H2D copy for a layer."""
|
||||
# Set up CPU data with known values
|
||||
for i in range(8):
|
||||
engine.k_cache_cpu[0, i].fill_(float(i))
|
||||
engine.v_cache_cpu[0, i].fill_(float(i + 100))
|
||||
|
||||
# Set gather indices: (cpu_block_id, gpu_slot)
|
||||
# GPU slot 0 gets CPU block 3, GPU slot 1 gets CPU block 0, etc.
|
||||
mappings = [(3, 0), (0, 1), (7, 2), (2, 3)]
|
||||
engine.update_gather_indices(layer_id=0, mappings=mappings)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Execute gathered H2D
|
||||
engine.gathered_h2d_layer(layer_id=0)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Verify: GPU slot 0 should have CPU block 3's data
|
||||
assert torch.allclose(engine.k_cache_gpu[0, 0],
|
||||
torch.full_like(engine.k_cache_gpu[0, 0], 3.0))
|
||||
# GPU slot 1 should have CPU block 0's data
|
||||
assert torch.allclose(engine.k_cache_gpu[0, 1],
|
||||
torch.full_like(engine.k_cache_gpu[0, 1], 0.0))
|
||||
# GPU slot 2 should have CPU block 7's data
|
||||
assert torch.allclose(engine.k_cache_gpu[0, 2],
|
||||
torch.full_like(engine.k_cache_gpu[0, 2], 7.0))
|
||||
# GPU slot 3 should have CPU block 2's data
|
||||
assert torch.allclose(engine.k_cache_gpu[0, 3],
|
||||
torch.full_like(engine.k_cache_gpu[0, 3], 2.0))
|
||||
|
||||
def test_multi_layer_independence(self, engine):
|
||||
"""Test that layers are independent."""
|
||||
# Set different data for each layer
|
||||
engine.k_cache_cpu[0, 0].fill_(1.0)
|
||||
engine.k_cache_cpu[1, 0].fill_(2.0)
|
||||
|
||||
# Prefetch layer 0
|
||||
event = engine.prefetch_block_async(0, 0, 0)
|
||||
event.synchronize()
|
||||
|
||||
# Verify only layer 0 was affected
|
||||
assert torch.allclose(engine.k_cache_gpu[0, 0],
|
||||
torch.full_like(engine.k_cache_gpu[0, 0], 1.0))
|
||||
# Layer 1 should be zeros (initial state)
|
||||
assert not torch.allclose(engine.k_cache_gpu[1, 0],
|
||||
torch.full_like(engine.k_cache_gpu[1, 0], 2.0))
|
||||
|
||||
|
||||
class TestOffloadEngineFixedAddresses:
|
||||
"""Tests verifying fixed address property for CUDA Graph compatibility."""
|
||||
|
||||
@pytest.fixture
|
||||
def engine(self):
|
||||
"""Create engine for address tests."""
|
||||
return OffloadEngine(
|
||||
num_layers=2,
|
||||
num_gpu_blocks=4,
|
||||
num_cpu_blocks=8,
|
||||
block_size=256,
|
||||
num_kv_heads=4,
|
||||
head_dim=64,
|
||||
dtype=torch.float16,
|
||||
num_streams=2,
|
||||
)
|
||||
|
||||
def test_gpu_cache_address_fixed(self, engine):
|
||||
"""Verify GPU cache addresses don't change."""
|
||||
k_ptr_before = engine.k_cache_gpu.data_ptr()
|
||||
v_ptr_before = engine.v_cache_gpu.data_ptr()
|
||||
|
||||
# Perform some operations - mappings is List[(cpu_block_id, gpu_slot)]
|
||||
mappings = [(0, 0), (1, 1), (2, 2), (3, 3)]
|
||||
engine.update_gather_indices(0, mappings)
|
||||
engine.gathered_h2d_layer(0)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Addresses should be the same
|
||||
assert engine.k_cache_gpu.data_ptr() == k_ptr_before
|
||||
assert engine.v_cache_gpu.data_ptr() == v_ptr_before
|
||||
|
||||
def test_gather_indices_gpu_address_fixed(self, engine):
|
||||
"""Verify gather indices GPU tensor address doesn't change."""
|
||||
ptr_before = engine.gather_indices_gpu.data_ptr()
|
||||
|
||||
# Update indices multiple times - mappings is List[(cpu_block_id, gpu_slot)]
|
||||
mappings = [(0, 0), (1, 1), (2, 2), (3, 3)]
|
||||
for _ in range(10):
|
||||
engine.update_gather_indices(0, mappings)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
assert engine.gather_indices_gpu.data_ptr() == ptr_before
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
167
tests/test_policies.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""Tests for eviction policies."""
|
||||
|
||||
import pytest
|
||||
from nanovllm.kvcache.policies.lru_policy import LRUPolicy
|
||||
from nanovllm.kvcache.policies.fifo_policy import FIFOPolicy
|
||||
from nanovllm.kvcache.policies import get_policy
|
||||
|
||||
|
||||
class TestLRUPolicy:
|
||||
"""Tests for LRU eviction policy."""
|
||||
|
||||
def test_basic_eviction(self):
|
||||
"""Test that LRU evicts least recently used block."""
|
||||
policy = LRUPolicy()
|
||||
|
||||
# Allocate blocks 0, 1, 2 in order
|
||||
policy.on_block_allocated(0, step=1)
|
||||
policy.on_block_allocated(1, step=2)
|
||||
policy.on_block_allocated(2, step=3)
|
||||
|
||||
# Access block 0 (makes it most recently used)
|
||||
policy.on_block_access(0, step=4)
|
||||
|
||||
# Should evict block 1 (least recently used)
|
||||
candidates = {0, 1, 2}
|
||||
victim = policy.select_victim(candidates)
|
||||
assert victim == 1, f"Expected block 1, got {victim}"
|
||||
|
||||
def test_access_updates_order(self):
|
||||
"""Test that access updates LRU order."""
|
||||
policy = LRUPolicy()
|
||||
|
||||
policy.on_block_allocated(0, step=1)
|
||||
policy.on_block_allocated(1, step=2)
|
||||
policy.on_block_allocated(2, step=3)
|
||||
|
||||
# Access all in reverse order
|
||||
policy.on_block_access(2, step=4)
|
||||
policy.on_block_access(1, step=5)
|
||||
policy.on_block_access(0, step=6)
|
||||
|
||||
# Block 2 is now LRU (accessed earliest after allocation update)
|
||||
candidates = {0, 1, 2}
|
||||
victim = policy.select_victim(candidates)
|
||||
assert victim == 2, f"Expected block 2, got {victim}"
|
||||
|
||||
def test_eviction_removes_from_tracking(self):
|
||||
"""Test that evicted blocks are removed from tracking."""
|
||||
policy = LRUPolicy()
|
||||
|
||||
policy.on_block_allocated(0, step=1)
|
||||
policy.on_block_allocated(1, step=2)
|
||||
|
||||
policy.on_block_evicted(0)
|
||||
|
||||
# Only block 1 should be a candidate
|
||||
candidates = {0, 1}
|
||||
victim = policy.select_victim(candidates)
|
||||
assert victim == 1, "Should select block 1 since 0 was evicted"
|
||||
|
||||
def test_batch_eviction_order(self):
|
||||
"""Test get_eviction_order returns blocks in LRU order."""
|
||||
policy = LRUPolicy()
|
||||
|
||||
for i in range(5):
|
||||
policy.on_block_allocated(i, step=i)
|
||||
|
||||
# Access blocks 2 and 4
|
||||
policy.on_block_access(2, step=10)
|
||||
policy.on_block_access(4, step=11)
|
||||
|
||||
candidates = {0, 1, 2, 3, 4}
|
||||
order = policy.get_eviction_order(candidates, count=3)
|
||||
|
||||
# Should be 0, 1, 3 (in that order, skipping 2 and 4 until needed)
|
||||
assert order == [0, 1, 3], f"Expected [0, 1, 3], got {order}"
|
||||
|
||||
|
||||
class TestFIFOPolicy:
|
||||
"""Tests for FIFO eviction policy."""
|
||||
|
||||
def test_basic_eviction(self):
|
||||
"""Test that FIFO evicts oldest allocated block."""
|
||||
policy = FIFOPolicy()
|
||||
|
||||
policy.on_block_allocated(0, step=1)
|
||||
policy.on_block_allocated(1, step=2)
|
||||
policy.on_block_allocated(2, step=3)
|
||||
|
||||
# Access doesn't change FIFO order
|
||||
policy.on_block_access(0, step=4)
|
||||
|
||||
candidates = {0, 1, 2}
|
||||
victim = policy.select_victim(candidates)
|
||||
assert victim == 0, f"Expected block 0 (oldest), got {victim}"
|
||||
|
||||
def test_access_does_not_update_order(self):
|
||||
"""Test that FIFO ignores access patterns."""
|
||||
policy = FIFOPolicy()
|
||||
|
||||
policy.on_block_allocated(0, step=1)
|
||||
policy.on_block_allocated(1, step=2)
|
||||
policy.on_block_allocated(2, step=3)
|
||||
|
||||
# Multiple accesses to block 0
|
||||
for i in range(10):
|
||||
policy.on_block_access(0, step=10 + i)
|
||||
|
||||
# Block 0 should still be evicted first (FIFO order)
|
||||
candidates = {0, 1, 2}
|
||||
victim = policy.select_victim(candidates)
|
||||
assert victim == 0, f"Expected block 0, got {victim}"
|
||||
|
||||
def test_prefetch_resets_order(self):
|
||||
"""Test that prefetch moves block to end of queue."""
|
||||
policy = FIFOPolicy()
|
||||
|
||||
policy.on_block_allocated(0, step=1)
|
||||
policy.on_block_allocated(1, step=2)
|
||||
policy.on_block_allocated(2, step=3)
|
||||
|
||||
# Prefetch block 0 (moves to end)
|
||||
policy.on_block_prefetched(0, step=4)
|
||||
|
||||
candidates = {0, 1, 2}
|
||||
victim = policy.select_victim(candidates)
|
||||
assert victim == 1, f"Expected block 1 (now oldest), got {victim}"
|
||||
|
||||
def test_batch_eviction_order(self):
|
||||
"""Test get_eviction_order returns blocks in FIFO order."""
|
||||
policy = FIFOPolicy()
|
||||
|
||||
for i in range(5):
|
||||
policy.on_block_allocated(i, step=i)
|
||||
|
||||
candidates = {0, 1, 2, 3, 4}
|
||||
order = policy.get_eviction_order(candidates, count=3)
|
||||
|
||||
assert order == [0, 1, 2], f"Expected [0, 1, 2], got {order}"
|
||||
|
||||
|
||||
class TestGetPolicy:
|
||||
"""Tests for policy factory function."""
|
||||
|
||||
def test_get_lru(self):
|
||||
"""Test getting LRU policy by name."""
|
||||
policy = get_policy("lru")
|
||||
assert isinstance(policy, LRUPolicy)
|
||||
|
||||
def test_get_fifo(self):
|
||||
"""Test getting FIFO policy by name."""
|
||||
policy = get_policy("fifo")
|
||||
assert isinstance(policy, FIFOPolicy)
|
||||
|
||||
def test_get_by_class_path(self):
|
||||
"""Test getting policy by full class path."""
|
||||
policy = get_policy("nanovllm.kvcache.policies.lru_policy.LRUPolicy")
|
||||
assert isinstance(policy, LRUPolicy)
|
||||
|
||||
def test_invalid_policy_name(self):
|
||||
"""Test that invalid policy name raises error."""
|
||||
with pytest.raises((ValueError, ImportError)):
|
||||
get_policy("invalid_policy")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||