diff --git a/CLAUDE.md b/CLAUDE.md index e236d36..7416e8b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -237,7 +237,6 @@ Warmup uses a reasonable sequence length (`block_size * 2`) instead of `max_mode | `max_num_seqs` | 512 | Max concurrent sequences | | `gpu_memory_utilization` | 0.9 | GPU memory fraction for KV cache | | `enforce_eager` | False | Disable CUDA graphs if True | -| `num_prefetch_blocks` | 2 | Ring buffer pipeline depth (deprecated, uses num_gpu_blocks) | ## Benchmarking diff --git a/bench_offload.py b/bench_offload.py index 2863a11..f3b26ff 100644 --- a/bench_offload.py +++ b/bench_offload.py @@ -109,7 +109,6 @@ def main(): max_num_batched_tokens=max_len, enable_cpu_offload=True, num_gpu_blocks=8, # Small GPU buffer for offload testing - num_prefetch_blocks=4, ) if not args.no_sparse: diff --git a/nanovllm/config.py b/nanovllm/config.py index da2aee5..e59a5eb 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -22,7 +22,6 @@ class Config: offload_policy: str = "lru" # "lru", "fifo", or full class path num_transfer_streams: int = 4 # Number of CUDA streams for async transfers num_gpu_blocks: int = -1 # User-specified GPU blocks count, -1 = auto (use max available) - num_prefetch_blocks: int = 2 # Number of prefetch blocks for three-region GPU buffer design # Computed fields for offload (set in __post_init__ or by ModelRunner) num_gpu_kvcache_blocks: int = -1 diff --git a/nanovllm/kvcache/__init__.py b/nanovllm/kvcache/__init__.py index e8eb7f9..02de400 100644 --- a/nanovllm/kvcache/__init__.py +++ b/nanovllm/kvcache/__init__.py @@ -58,14 +58,12 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager: from nanovllm.kvcache.policies import get_policy policy = get_policy(getattr(config, 'offload_policy', 'lru')) - num_prefetch_blocks = getattr(config, 'num_prefetch_blocks', 2) return HybridKVCacheManager( num_gpu_slots=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks, block_size=config.kvcache_block_size, policy=policy, - num_prefetch_blocks=num_prefetch_blocks, ) diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py index 5be9e94..a4004b6 100644 --- a/nanovllm/kvcache/hybrid_manager.py +++ b/nanovllm/kvcache/hybrid_manager.py @@ -86,7 +86,6 @@ class HybridKVCacheManager(KVCacheManager): num_cpu_blocks: int, block_size: int, policy: Optional[EvictionPolicy] = None, - num_prefetch_blocks: int = 2, ): """ Initialize hybrid manager with CPU-primary ring buffer design. 
@@ -99,13 +98,11 @@ class HybridKVCacheManager(KVCacheManager): num_cpu_blocks: Number of CPU pool blocks (primary storage) block_size: Tokens per block policy: Eviction policy (default: LRU, used for prefix cache management) - num_prefetch_blocks: Number of blocks for ring buffer pipeline (deprecated, ring_slots = num_gpu_slots) """ self._block_size = block_size self.num_gpu_slots = num_gpu_slots self.num_cpu_blocks = num_cpu_blocks self.total_blocks = num_gpu_slots + num_cpu_blocks - self.num_prefetch_blocks = num_prefetch_blocks # Ring buffer design parameter (deprecated) # Eviction policy self.policy = policy or LRUPolicy() @@ -170,7 +167,6 @@ class HybridKVCacheManager(KVCacheManager): num_kv_heads=num_kv_heads, head_dim=head_dim, dtype=dtype, - num_prefetch_blocks=self.num_prefetch_blocks, ) def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]: diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index c0946e0..f5cea8a 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -53,7 +53,6 @@ class OffloadEngine: head_dim: int, dtype: torch.dtype = torch.float16, num_streams: int = 4, - num_prefetch_blocks: int = 2, ): self.num_layers = num_layers self.num_gpu_blocks = num_gpu_blocks @@ -82,8 +81,6 @@ class OffloadEngine: self.decode_load_slots = list(range(1, num_gpu_blocks)) self.num_decode_load_slots = len(self.decode_load_slots) - # Keep num_prefetch_blocks for compatibility (used as chunk size for loading) - self.num_prefetch_blocks = num_prefetch_blocks self.num_gpu_slots = num_gpu_blocks # alias logger.info(f"Unified Ring Buffer: {self.num_ring_slots} slots total") diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index ef6c1f5..adb8546 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -378,9 +378,9 @@ class Attention(nn.Module): offload_engine = kvcache_manager.offload_engine - # Use prefetch_size as chunk size for double buffering - # This ensures both Compute and Prefetch regions can hold a full chunk - chunk_size = offload_engine.num_prefetch_blocks + # Chunk size = capacity of each double buffer region (compute/prefetch) + # Each region uses half of decode_load_slots + chunk_size = max(1, len(offload_engine.decode_load_slots) // 2) num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size o_acc = None diff --git a/tests/test_chunked_attention.py b/tests/test_chunked_attention.py new file mode 100644 index 0000000..48d114c --- /dev/null +++ b/tests/test_chunked_attention.py @@ -0,0 +1,169 @@ +""" +Test script for chunked attention correctness. + +Validates that chunked prefill using flash_attn_with_lse + merge_attention_outputs +produces the same result as full flash_attn_varlen_func. + +Scenario: Simulating chunked prefill where we process query chunk by chunk. 
+For each query chunk i: +- KV contains all tokens from chunk 0 to chunk i +- Previous KV chunks (0 to i-1): full attention (no causal mask) +- Current KV chunk (i): causal attention (diagonal block) +""" + +import torch +from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_func +from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs + +# ============================================================ +# Utility Functions +# ============================================================ + +def compute_chunked_prefill_for_chunk( + q_chunk: torch.Tensor, + kv_chunks: list, + current_chunk_idx: int, +) -> torch.Tensor: + """ + Compute attention for a single query chunk against all KV chunks up to current. + + This simulates chunked prefill for query chunk `current_chunk_idx`: + - KV chunks 0 to current_chunk_idx-1: full attention (all previous tokens visible) + - KV chunk current_chunk_idx: causal attention (diagonal block) + + Args: + q_chunk: [batch, chunk_size, nheads, headdim] - current query chunk + kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim] + current_chunk_idx: Index of the current chunk being processed + + Returns: + out: [batch, chunk_size, nheads, headdim] + """ + accumulated_o = None + accumulated_lse = None + + for i in range(current_chunk_idx + 1): + k_chunk, v_chunk = kv_chunks[i] + + # Previous chunks: no causal mask (all tokens visible) + # Current chunk (diagonal): causal mask + is_diagonal = (i == current_chunk_idx) + + chunk_o, chunk_lse = flash_attn_with_lse( + q_chunk, k_chunk, v_chunk, causal=is_diagonal + ) + + if accumulated_o is None: + accumulated_o = chunk_o + accumulated_lse = chunk_lse + else: + accumulated_o, accumulated_lse = merge_attention_outputs( + accumulated_o, accumulated_lse, + chunk_o, chunk_lse + ) + + return accumulated_o + + +def compute_reference_causal( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, +) -> torch.Tensor: + """ + Compute reference causal attention using flash_attn_func. 
+ + Args: + q, k, v: [batch, seqlen, nheads, headdim] + + Returns: + out: [batch, seqlen, nheads, headdim] + """ + return flash_attn_func(q, k, v, causal=True) + + +# ============================================================ +# Main Test Script +# ============================================================ + +torch.manual_seed(42) + +# Test configurations: (batch, num_chunks, chunk_size, nheads, headdim) +TEST_CASES = [ + (1, 4, 256, 8, 128), + (1, 4, 512, 8, 128), + (1, 8, 512, 8, 128), + (1, 4, 1024, 8, 128), + (1, 4, 1024, 32, 128), # More heads + (1, 8, 256, 8, 64), # Smaller head dim +] + +DTYPES = [torch.float16, torch.bfloat16] + +print("=" * 80) +print("Test: Chunked Prefill Attention vs Reference (flash_attn_func causal)") +print("=" * 80) +print("Simulating chunked prefill: Q chunk attends to all KV chunks up to current") +print(" - Previous KV chunks: full attention (no causal mask)") +print(" - Current KV chunk (diagonal): causal attention") +print() + +all_passed = True + +for dtype in DTYPES: + print(f"--- dtype: {dtype} ---") + + for batch, num_chunks, chunk_size, nheads, headdim in TEST_CASES: + seqlen = num_chunks * chunk_size + + # Generate full Q, K, V + q_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype) + k_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype) + v_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype) + + # Reference: full causal attention + out_ref = compute_reference_causal(q_full, k_full, v_full) + + # Split into chunks + q_chunks = [q_full[:, i*chunk_size:(i+1)*chunk_size] for i in range(num_chunks)] + kv_chunks = [ + (k_full[:, i*chunk_size:(i+1)*chunk_size], + v_full[:, i*chunk_size:(i+1)*chunk_size]) + for i in range(num_chunks) + ] + + # Compute chunked prefill for each query chunk + out_chunks = [] + for chunk_idx in range(num_chunks): + chunk_out = compute_chunked_prefill_for_chunk( + q_chunks[chunk_idx], + kv_chunks, + chunk_idx, + ) + out_chunks.append(chunk_out) + + # Concatenate chunked outputs + out_chunked = torch.cat(out_chunks, dim=1) + + # Compare + diff = (out_ref - out_chunked).abs() + max_diff = diff.max().item() + mean_diff = diff.mean().item() + + # Tolerance: fp16/bf16 have limited precision + tol = 1e-2 + passed = max_diff < tol + all_passed = all_passed and passed + + status = "PASS" if passed else "FAIL" + print( + f"[{status}] seqlen={seqlen:5d} chunks={num_chunks} " + f"chunk_size={chunk_size:4d} heads={nheads:2d} dim={headdim:3d} " + f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}" + ) + + print() + +print("=" * 80) +assert all_passed, "Some tests failed!" +print("test_chunked_attention: PASSED") diff --git a/tests/test_prefill.py b/tests/test_prefill.py index b955520..9e501c0 100644 --- a/tests/test_prefill.py +++ b/tests/test_prefill.py @@ -5,17 +5,20 @@ Demonstrates: LLM initialization, prefill execution with CPU offload enabled. 
""" import os +os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG" + from random import randint, seed from nanovllm import LLM, SamplingParams + # ============================================================ # Configuration # ============================================================ MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/") -MAX_MODEL_LEN = 8192 -NUM_GPU_BLOCKS = 4 -INPUT_LEN = 4096 +MAX_MODEL_LEN = 32 * 1024 +NUM_GPU_BLOCKS = 2 +INPUT_LEN = 16 * 1024 # ============================================================ # Main Test Script @@ -28,6 +31,7 @@ llm = LLM( max_model_len=MAX_MODEL_LEN, max_num_batched_tokens=MAX_MODEL_LEN, enable_cpu_offload=True, + kvcache_block_size=1024, num_gpu_blocks=NUM_GPU_BLOCKS, ) diff --git a/tests/test_sim.py b/tests/test_sim.py new file mode 100644 index 0000000..a113ddc --- /dev/null +++ b/tests/test_sim.py @@ -0,0 +1,286 @@ +""" +Chunked Prefill + KV Cache Offload Simulation v2 + +改进: +1. 简化日志输出 +2. 添加reduce时间 +3. 计算必须等待KV load完成 +""" + +import threading +import time +from dataclasses import dataclass +from typing import Optional +from concurrent.futures import ThreadPoolExecutor, Future + +# ============== 配置参数 ============== +NUM_CHUNKS = 8 +GPU_SLOTS = 4 + +# 模拟时间 (秒) +TIME_COMPUTE_BLOCK = 0.10 # 计算一个attention block +TIME_REDUCE = 0.03 # 两个partial result做一次reduce +TIME_TRANSFER = 0.08 # 传输一个KV cache +TIME_PROJ = 0.02 # projection生成KV + +# ============== 全局时间基准 ============== +START_TIME = None + +def now() -> float: + """返回相对开始的时间(ms)""" + return (time.time() - START_TIME) * 1000 + +def log_compute(msg: str): + """计算队列日志(无缩进)""" + print(f"[{now():7.1f}ms] [COMPUTE] {msg}") + +def log_transfer(msg: str): + """传输队列日志(缩进)""" + print(f"[{now():7.1f}ms] [TRANSFER] {msg}") + +def log_info(msg: str): + """一般信息""" + print(f"[{now():7.1f}ms] {msg}") + +# ============== GPU Slot管理 ============== +class GPUSlots: + def __init__(self, num_slots: int): + self.slots = [None] * num_slots # slot_id -> kv_idx + self.kv_to_slot = {} # kv_idx -> slot_id + self.lock = threading.Lock() + # KV ready events: kv_idx -> Event + self.kv_ready = {} + + def alloc(self, kv_idx: int) -> int: + with self.lock: + for sid, val in enumerate(self.slots): + if val is None: + self.slots[sid] = kv_idx + self.kv_to_slot[kv_idx] = sid + # 创建ready event + if kv_idx not in self.kv_ready: + self.kv_ready[kv_idx] = threading.Event() + return sid + raise RuntimeError(f"No free slot for KV{kv_idx}") + + def free(self, slot_id: int): + with self.lock: + kv_idx = self.slots[slot_id] + if kv_idx is not None: + del self.kv_to_slot[kv_idx] + # 清除event + if kv_idx in self.kv_ready: + del self.kv_ready[kv_idx] + self.slots[slot_id] = None + + def free_kv(self, kv_idx: int): + with self.lock: + if kv_idx in self.kv_to_slot: + sid = self.kv_to_slot[kv_idx] + self.slots[sid] = None + del self.kv_to_slot[kv_idx] + if kv_idx in self.kv_ready: + del self.kv_ready[kv_idx] + + def mark_ready(self, kv_idx: int): + """标记KV已就绪(load完成或proj完成)""" + with self.lock: + if kv_idx in self.kv_ready: + self.kv_ready[kv_idx].set() + + def wait_ready(self, kv_idx: int): + """等待KV就绪""" + with self.lock: + event = self.kv_ready.get(kv_idx) + if event: + event.wait() + + def has_kv(self, kv_idx: int) -> bool: + with self.lock: + return kv_idx in self.kv_to_slot + + def state(self) -> str: + with self.lock: + return "[" + "][".join( + f"KV{v}" if v is not None else "----" + for v in self.slots + ) + "]" + +# ============== 操作执行 ============== +class Executor: + def __init__(self, gpu: GPUSlots): + self.gpu = gpu + 
+        self.compute_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Compute")
+        self.transfer_pool = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Transfer")
+
+    def proj_kv(self, q_idx: int) -> Future:
+        """Projection produces the KV for this chunk; returns a Future."""
+        def task():
+            log_compute(f"PROJ Q{q_idx}->KV{q_idx} START")
+            time.sleep(TIME_PROJ)
+            slot_id = self.gpu.alloc(q_idx)
+            self.gpu.mark_ready(q_idx)  # Projection done, KV is immediately usable
+            log_compute(f"PROJ Q{q_idx}->KV{q_idx} END slot={slot_id} | {self.gpu.state()}")
+            return slot_id
+        return self.compute_pool.submit(task)
+
+    def compute_attn(self, q_idx: int, kv_indices: list) -> Future:
+        """Compute an attention block; waits until all required KV are ready."""
+        def task():
+            # Wait for every required KV to become ready
+            for kv_idx in kv_indices:
+                self.gpu.wait_ready(kv_idx)
+
+            kv_str = ",".join(map(str, kv_indices))
+            log_compute(f"ATTN Q{q_idx}*KV[{kv_str}] START")
+            time.sleep(TIME_COMPUTE_BLOCK * len(kv_indices))
+            log_compute(f"ATTN Q{q_idx}*KV[{kv_str}] END")
+            return (q_idx, kv_indices)
+        return self.compute_pool.submit(task)
+
+    def reduce(self, q_idx: int, num_partials: int) -> Future:
+        """Online-softmax reduce of multiple partial results."""
+        def task():
+            if num_partials <= 1:
+                return
+            # n partials require n-1 pairwise reduces
+            num_reduces = num_partials - 1
+            log_compute(f"REDUCE Q{q_idx} ({num_partials} partials) START")
+            time.sleep(TIME_REDUCE * num_reduces)
+            log_compute(f"REDUCE Q{q_idx} END")
+        return self.compute_pool.submit(task)
+
+    def load_kv(self, kv_idx: int) -> Future:
+        """Load a KV block from CPU to GPU."""
+        def task():
+            if self.gpu.has_kv(kv_idx):
+                log_transfer(f"LOAD KV{kv_idx} SKIP (already on GPU)")
+                return kv_idx
+
+            slot_id = self.gpu.alloc(kv_idx)
+            log_transfer(f"LOAD KV{kv_idx} START -> slot{slot_id}")
+            time.sleep(TIME_TRANSFER)
+            self.gpu.mark_ready(kv_idx)  # Load done, mark as ready
+            log_transfer(f"LOAD KV{kv_idx} END | {self.gpu.state()}")
+            return kv_idx
+        return self.transfer_pool.submit(task)
+
+    def offload_kv(self, kv_idx: int) -> Future:
+        """Offload a KV block from GPU to CPU."""
+        def task():
+            log_transfer(f"OFFLOAD KV{kv_idx} START")
+            time.sleep(TIME_TRANSFER)
+            self.gpu.free_kv(kv_idx)
+            log_transfer(f"OFFLOAD KV{kv_idx} END | {self.gpu.state()}")
+            return kv_idx
+        return self.transfer_pool.submit(task)
+
+    def shutdown(self):
+        self.compute_pool.shutdown(wait=True)
+        self.transfer_pool.shutdown(wait=True)
+
+# ============== Scheduler ==============
+def schedule_query(exe: Executor, q_idx: int):
+    """Schedule processing of a single query chunk."""
+    print(f"\n{'='*50}")
+    log_info(f"===== Query {q_idx} START =====")
+
+    hist_kv = list(range(q_idx))  # Historical KV: 0 .. q_idx-1
+    num_partials = 0
+
+    # Phase 1: Projection produces the current KV
+    proj_fut = exe.proj_kv(q_idx)
+    proj_fut.result()  # Wait for completion
+
+    # Phase 2: Diagonal-block compute + prefetch of historical KV in parallel
+    # Start the diagonal-block compute
+    diag_fut = exe.compute_attn(q_idx, [q_idx])
+    num_partials += 1
+
+    # Prefetch historical KV concurrently (at most GPU_SLOTS - 1 slots are free)
+    prefetch_slots = min(len(hist_kv), GPU_SLOTS - 1)
+    prefetch_kv = hist_kv[:prefetch_slots]
+    prefetch_futs = [exe.load_kv(kv) for kv in prefetch_kv]
+
+    # Wait for the diagonal block to finish
+    diag_fut.result()
+
+    # Phase 3: Offload the current KV to free its slot
+    offload_fut = exe.offload_kv(q_idx)
+
+    # Wait for the prefetch to finish, then compute this batch of historical KV
+    for f in prefetch_futs:
+        f.result()
+
+    if prefetch_kv:
+        hist_fut = exe.compute_attn(q_idx, prefetch_kv)
+        num_partials += 1
+    else:
+        hist_fut = None
+
+    # Wait for the offload to finish
+    offload_fut.result()
+
+    # Phase 4: Process the remaining historical KV
+    remaining_kv = hist_kv[prefetch_slots:]
+    computed_kv = prefetch_kv.copy()
+
+    while remaining_kv:
+        # Wait for the previous batch's compute to finish
+        if hist_fut:
+            hist_fut.result()
+
+        # Free the KV blocks that have already been computed
+        for kv in computed_kv:
+            exe.gpu.free_kv(kv)
+
+        # Load the next batch
+        batch_size = min(len(remaining_kv), GPU_SLOTS)
+        batch_kv = remaining_kv[:batch_size]
+        remaining_kv = remaining_kv[batch_size:]
+
+        load_futs = [exe.load_kv(kv) for kv in batch_kv]
+        for f in load_futs:
+            f.result()
+
+        # Compute this batch
+        hist_fut = exe.compute_attn(q_idx, batch_kv)
+        num_partials += 1
+        computed_kv = batch_kv
+
+    # Wait for the last batch's compute to finish
+    if hist_fut:
+        hist_fut.result()
+
+    # Clean up the GPU slots
+    for kv in computed_kv:
+        exe.gpu.free_kv(kv)
+
+    # Phase 5: Reduce all partial results
+    reduce_fut = exe.reduce(q_idx, num_partials)
+    reduce_fut.result()
+
+    log_info(f"===== Query {q_idx} END =====")
+
+def main():
+    global START_TIME
+    START_TIME = time.time()
+
+    print("Chunked Prefill + KV Cache Offload Simulation v2")
+    print(f"Config: {NUM_CHUNKS} chunks, {GPU_SLOTS} GPU slots")
+    print(f"Time: compute={TIME_COMPUTE_BLOCK}s, transfer={TIME_TRANSFER}s, reduce={TIME_REDUCE}s")
+
+    gpu = GPUSlots(GPU_SLOTS)
+    exe = Executor(gpu)
+
+    try:
+        for q_idx in range(NUM_CHUNKS):
+            schedule_query(exe, q_idx)
+
+        print(f"\n{'='*50}")
+        log_info(f"ALL DONE! Total: {now():.1f}ms")
+    finally:
+        exe.shutdown()
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
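
The `nanovllm/layers/attention.py` hunk above replaces the fixed `num_prefetch_blocks` chunk size with half of `decode_load_slots`. A quick worked example of that arithmetic follows; the concrete values are illustrative only (`num_gpu_blocks = 8` is borrowed from `bench_offload.py`, and the 20-entry `cpu_block_table` is made up for the example).

# Illustrative sketch of the new chunk sizing; values are examples, not fixed defaults.
num_gpu_blocks = 8
decode_load_slots = list(range(1, num_gpu_blocks))   # as in offload_engine.py: slots 1..7, slot 0 is not a decode-load slot
chunk_size = max(1, len(decode_load_slots) // 2)     # 7 // 2 = 3 blocks per double-buffer region
cpu_block_table = list(range(20))                    # pretend 20 KV blocks live in the CPU pool
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size  # ceil(20 / 3) = 7
print(chunk_size, num_chunks)                        # -> 3 7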
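
`tests/test_chunked_attention.py` relies on `merge_attention_outputs` to combine per-chunk partial attention results through their log-sum-exp (LSE) statistics. Its implementation in `nanovllm/kvcache/chunked_attention.py` is not shown in this diff, so the following is only a minimal reference sketch of the merge rule such a helper is expected to apply, assuming `o` tensors of shape [batch, seqlen, nheads, headdim] and `lse` tensors of shape [batch, nheads, seqlen] (the usual flash-attn convention); treat it as an illustration, not the repository code.

import torch

def merge_attention_outputs_ref(o1, lse1, o2, lse2):
    """Reference LSE merge of two partial attention results over disjoint KV chunks.

    Assumed shapes: o* is [batch, seqlen, nheads, headdim], lse* is [batch, nheads, seqlen].
    """
    lse = torch.logaddexp(lse1, lse2)                          # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # -> [batch, seqlen, nheads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    merged = w1 * o1.float() + w2 * o2.float()                 # rescale and sum partial outputs
    return merged.to(o1.dtype), lse

Because the merged (o, lse) pair has the same form as its inputs, the merge is associative, which is why the test can fold the per-chunk results in a single left-to-right pass.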