From 761929390eee484bd6bada26789776869fed0088 Mon Sep 17 00:00:00 2001
From: Zijie Tian
Date: Wed, 10 Dec 2025 00:44:57 +0800
Subject: [PATCH] [bench] Added vllm vs nano-vllm bench.

---
 CLAUDE.md                    | 103 +++++++++++++++++++++++++++++++++++++++
 bench.py                     |  64 ++++++++++++++++--------
 bench_vllm.py                |  63 +++++++++++++++++++++++
 nanovllm/layers/attention.py |   2 +-
 4 files changed, 213 insertions(+), 19 deletions(-)
 create mode 100644 CLAUDE.md
 create mode 100644 bench_vllm.py

diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..adcd99b
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,103 @@
+# CLAUDE.md
+
+This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+
+## Overview
+
+Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Currently supports Qwen3 models.
+
+## Commands
+
+```bash
+# Install
+pip install -e .
+
+# Run example
+python example.py
+
+# Run benchmark
+python bench.py
+```
+
+## Architecture
+
+### Core Components
+
+**LLMEngine** (`nanovllm/engine/llm_engine.py`):
+- Main entry point, wraps ModelRunner and Scheduler
+- Handles tokenization and multi-process tensor parallelism coordination
+- `generate()` method runs the prefill-decode loop until all sequences finish
+
+**ModelRunner** (`nanovllm/engine/model_runner.py`):
+- Loads model weights, allocates KV cache, captures CUDA graphs
+- Rank 0 is the main process; ranks 1+ run in separate processes via `loop()` waiting on shared memory events
+- `run()` prepares inputs and executes model forward pass
+
+**Scheduler** (`nanovllm/engine/scheduler.py`):
+- Two-phase scheduling: prefill (waiting queue) then decode (running queue)
+- Handles preemption when memory is constrained by moving sequences back to waiting
+
+**BlockManager** (`nanovllm/engine/block_manager.py`):
+- Paged attention block allocation with prefix caching via xxhash
+- Blocks are 256 tokens by default, tracked with reference counting
+
+**Sequence** (`nanovllm/engine/sequence.py`):
+- Tracks token IDs, block table, and sampling parameters per request
+- Custom `__getstate__`/`__setstate__` for efficient pickling across processes
+
+### Model Implementation
+
+**Qwen3ForCausalLM** (`nanovllm/models/qwen3.py`):
+- Standard transformer: embedding → decoder layers → RMSNorm → LM head
+- Uses `packed_modules_mapping` for weight loading (q/k/v → qkv_proj, gate/up → gate_up_proj)
+
+**Attention** (`nanovllm/layers/attention.py`):
+- Uses FlashAttention (`flash_attn_varlen_func` for prefill, `flash_attn_with_kvcache` for decode)
+- Custom Triton kernel `store_kvcache_kernel` for KV cache writes
+
+**Parallel Layers** (`nanovllm/layers/linear.py`, `embed_head.py`):
+- Tensor parallelism via column/row parallel linear layers with custom weight loaders
+
+### Key Design Patterns
+
+- **Global Context**: `nanovllm/utils/context.py` stores attention metadata (cu_seqlens, slot_mapping, block_tables) accessed via `get_context()`/`set_context()`
+- **CUDA Graph Capture**: Decode phase uses captured graphs for batch sizes 1, 2, 4, 8, 16, 32... up to `max_num_seqs` (capped at 512); see the sketch after this list
+- **Shared Memory IPC**: Tensor parallel workers receive commands via pickled data in SharedMemory, synchronized with Events
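+
+A minimal sketch of the graph bucketing described above; `pick_graph_batch_size` is a
+hypothetical helper name, not the actual `ModelRunner` API:
+
+```python
+def pick_graph_batch_size(batch_size: int, captured: list[int]) -> int:
+    """Pick the smallest captured batch size that fits, so its CUDA graph can be replayed.
+
+    Assumes `captured` is sorted (e.g. 1, 2, 4, 8, 16, 32, ... up to
+    min(max_num_seqs, 512), as listed above) and that its last entry covers
+    the largest possible decode batch; inputs are padded up to the returned size.
+    """
+    return next(bs for bs in captured if bs >= batch_size)
+```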
+
+### Config Defaults
+
+- `max_num_batched_tokens`: 16384
+- `max_num_seqs`: 512
+- `kvcache_block_size`: 256
+- `gpu_memory_utilization`: 0.9
+- `enforce_eager`: False (enables CUDA graphs)
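+
+These defaults can be overridden per run by passing keyword arguments to `LLM`, as
+`bench.py` does. A minimal, illustrative sketch (model path taken from the benchmarks;
+adjust for your setup):
+
+```python
+import os
+from nanovllm import LLM, SamplingParams
+
+llm = LLM(
+    os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
+    enforce_eager=False,  # keep CUDA graph capture enabled
+    max_model_len=4096,   # cap on prompt + generated tokens
+)
+outputs = llm.generate(["Benchmark: "], SamplingParams(temperature=0.6, max_tokens=16))
+```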
diff --git a/bench.py b/bench.py
index 8e61d65..535251f 100644
--- a/bench.py
+++ b/bench.py
@@ -2,30 +2,58 @@ import os
 import time
 from random import randint, seed
 from nanovllm import LLM, SamplingParams
-# from vllm import LLM, SamplingParams
+
+
+def bench_decode(llm, num_seqs, max_input_len, max_output_len):
+    """Benchmark decode performance (original test)"""
+    seed(0)
+    prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
+    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
+
+    t = time.time()
+    llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
+    t = time.time() - t
+    total_output_tokens = sum(sp.max_tokens for sp in sampling_params)
+    throughput = total_output_tokens / t
+    print(f"[Decode] Output: {total_output_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
+
+
+def bench_prefill(llm, num_seqs, input_len):
+    """Benchmark prefill performance"""
+    seed(0)
+    # Fixed length input, minimal output to focus on prefill
+    prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
+    sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
+
+    t = time.time()
+    llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
+    t = time.time() - t
+    total_input_tokens = num_seqs * input_len
+    throughput = total_input_tokens / t
+    print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
 
 
 def main():
-    seed(0)
-    num_seqs = 256
-    max_input_len = 1024
-    max_ouput_len = 1024
-
-    path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
+    path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
     llm = LLM(path, enforce_eager=False, max_model_len=4096)
 
-    prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
-    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
-    # uncomment the following line for vllm
-    # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
-
+    # Warmup
     llm.generate(["Benchmark: "], SamplingParams())
-    t = time.time()
-    llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
-    t = (time.time() - t)
-    total_tokens = sum(sp.max_tokens for sp in sampling_params)
-    throughput = total_tokens / t
-    print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
+
+    print("=" * 60)
+    print("Prefill Benchmark")
+    print("=" * 60)
+    bench_prefill(llm, num_seqs=1, input_len=1024)
+    bench_prefill(llm, num_seqs=1, input_len=2048)
+    bench_prefill(llm, num_seqs=1, input_len=4095)
+    bench_prefill(llm, num_seqs=16, input_len=1024)
+    bench_prefill(llm, num_seqs=64, input_len=1024)
+
+    print("=" * 60)
+    print("Decode Benchmark")
+    print("=" * 60)
+    bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=1024)
+    bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024)
 
 
 if __name__ == "__main__":
diff --git a/bench_vllm.py b/bench_vllm.py
new file mode 100644
index 0000000..ff5789a
--- /dev/null
+++ b/bench_vllm.py
@@ -0,0 +1,63 @@
+import os
+os.environ["VLLM_USE_V1"] = "1"
+import time
+from random import randint, seed
+from vllm import LLM, SamplingParams
+
+
+def bench_decode(llm, num_seqs, max_input_len, max_output_len):
+    """Benchmark decode performance (original test)"""
+    seed(0)
+    prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
+    sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
+    prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
+
+    t = time.time()
+    llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
+    t = time.time() - t
+    total_output_tokens = sum(sp.max_tokens for sp in sampling_params)
+    throughput = total_output_tokens / t
+    print(f"[Decode] Output: {total_output_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
+
+
+def bench_prefill(llm, num_seqs, input_len):
+    """Benchmark prefill performance"""
+    seed(0)
+    # Fixed length input, minimal output to focus on prefill
+    prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
+    sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
+    prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
+
+    t = time.time()
+    llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
+    t = time.time() - t
+    total_input_tokens = num_seqs * input_len
+    throughput = total_input_tokens / t
+    print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
+
+
+def main():
+    path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
+    llm = LLM(path, enforce_eager=False, max_model_len=4096, max_num_seqs=128, gpu_memory_utilization=0.9)
+
+    # Warmup
+    llm.generate([dict(prompt_token_ids=[0])], SamplingParams())
+
+    print("=" * 60)
+    print("Prefill Benchmark")
+    print("=" * 60)
+    bench_prefill(llm, num_seqs=1, input_len=1024)
+    bench_prefill(llm, num_seqs=1, input_len=2048)
+    bench_prefill(llm, num_seqs=1, input_len=4095)
+    bench_prefill(llm, num_seqs=16, input_len=1024)
+    bench_prefill(llm, num_seqs=64, input_len=1024)
+
+    print("=" * 60)
+    print("Decode Benchmark")
+    print("=" * 60)
+    bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=1024)
+    bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index e416139..f5046b3 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -3,7 +3,7 @@ from torch import nn
 import triton
 import triton.language as tl
 
-from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 
 from nanovllm.utils.context import get_context
 