[bench] Added vllm vs nano-vllm bench.
This commit is contained in:
73
CLAUDE.md
Normal file
73
CLAUDE.md
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
# CLAUDE.md
|
||||||
|
|
||||||
|
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Currently supports Qwen3 models.
|
||||||
|
|
||||||
|
## Commands
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Install
|
||||||
|
pip install -e .
|
||||||
|
|
||||||
|
# Run example
|
||||||
|
python example.py
|
||||||
|
|
||||||
|
# Run benchmark
|
||||||
|
python bench.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### Core Components
|
||||||
|
|
||||||
|
**LLMEngine** (`nanovllm/engine/llm_engine.py`):
|
||||||
|
- Main entry point, wraps ModelRunner and Scheduler
|
||||||
|
- Handles tokenization and multi-process tensor parallelism coordination
|
||||||
|
- `generate()` method runs the prefill-decode loop until all sequences finish
|
||||||
|
|
||||||
|
**ModelRunner** (`nanovllm/engine/model_runner.py`):
|
||||||
|
- Loads model weights, allocates KV cache, captures CUDA graphs
|
||||||
|
- Rank 0 is the main process; ranks 1+ run in separate processes via `loop()` waiting on shared memory events
|
||||||
|
- `run()` prepares inputs and executes model forward pass
|
||||||
|
|
||||||
|
**Scheduler** (`nanovllm/engine/scheduler.py`):
|
||||||
|
- Two-phase scheduling: prefill (waiting queue) then decode (running queue)
|
||||||
|
- Handles preemption when memory is constrained by moving sequences back to waiting
|
||||||
|
|
||||||
|
**BlockManager** (`nanovllm/engine/block_manager.py`):
|
||||||
|
- Paged attention block allocation with prefix caching via xxhash
|
||||||
|
- Blocks are 256 tokens by default, tracked with reference counting
|
||||||
|
|
||||||
|
**Sequence** (`nanovllm/engine/sequence.py`):
|
||||||
|
- Tracks token IDs, block table, and sampling parameters per request
|
||||||
|
- Custom `__getstate__`/`__setstate__` for efficient pickling across processes
|
||||||
|
|
||||||
|
### Model Implementation
|
||||||
|
|
||||||
|
**Qwen3ForCausalLM** (`nanovllm/models/qwen3.py`):
|
||||||
|
- Standard transformer: embedding → decoder layers → RMSNorm → LM head
|
||||||
|
- Uses `packed_modules_mapping` for weight loading (q/k/v → qkv_proj, gate/up → gate_up_proj)
|
||||||
|
|
||||||
|
**Attention** (`nanovllm/layers/attention.py`):
|
||||||
|
- Uses FlashAttention (`flash_attn_varlen_func` for prefill, `flash_attn_with_kvcache` for decode)
|
||||||
|
- Custom Triton kernel `store_kvcache_kernel` for KV cache writes
|
||||||
|
|
||||||
|
**Parallel Layers** (`nanovllm/layers/linear.py`, `embed_head.py`):
|
||||||
|
- Tensor parallelism via column/row parallel linear layers with custom weight loaders
|
||||||
|
|
||||||
|
### Key Design Patterns
|
||||||
|
|
||||||
|
- **Global Context**: `nanovllm/utils/context.py` stores attention metadata (cu_seqlens, slot_mapping, block_tables) accessed via `get_context()`/`set_context()`
|
||||||
|
- **CUDA Graph Capture**: Decode phase uses captured graphs for batch sizes 1, 2, 4, 8, 16, 32... up to max_num_seqs (capped at 512)
|
||||||
|
- **Shared Memory IPC**: Tensor parallel workers receive commands via pickled data in SharedMemory, synchronized with Events
|
||||||
|
|
||||||
|
### Config Defaults
|
||||||
|
|
||||||
|
- `max_num_batched_tokens`: 16384
|
||||||
|
- `max_num_seqs`: 512
|
||||||
|
- `kvcache_block_size`: 256
|
||||||
|
- `gpu_memory_utilization`: 0.9
|
||||||
|
- `enforce_eager`: False (enables CUDA graphs)
|
||||||
64
bench.py
64
bench.py
@@ -2,30 +2,58 @@ import os
|
|||||||
import time
|
import time
|
||||||
from random import randint, seed
|
from random import randint, seed
|
||||||
from nanovllm import LLM, SamplingParams
|
from nanovllm import LLM, SamplingParams
|
||||||
# from vllm import LLM, SamplingParams
|
|
||||||
|
|
||||||
|
def bench_decode(llm, num_seqs, max_input_len, max_output_len):
|
||||||
|
"""Benchmark decode performance (original test)"""
|
||||||
|
seed(0)
|
||||||
|
prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
|
||||||
|
sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
|
||||||
|
|
||||||
|
t = time.time()
|
||||||
|
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
||||||
|
t = time.time() - t
|
||||||
|
total_output_tokens = sum(sp.max_tokens for sp in sampling_params)
|
||||||
|
throughput = total_output_tokens / t
|
||||||
|
print(f"[Decode] Output: {total_output_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
||||||
|
|
||||||
|
|
||||||
|
def bench_prefill(llm, num_seqs, input_len):
|
||||||
|
"""Benchmark prefill performance"""
|
||||||
|
seed(0)
|
||||||
|
# Fixed length input, minimal output to focus on prefill
|
||||||
|
prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
|
||||||
|
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
|
||||||
|
|
||||||
|
t = time.time()
|
||||||
|
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
||||||
|
t = time.time() - t
|
||||||
|
total_input_tokens = num_seqs * input_len
|
||||||
|
throughput = total_input_tokens / t
|
||||||
|
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
seed(0)
|
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
||||||
num_seqs = 256
|
|
||||||
max_input_len = 1024
|
|
||||||
max_ouput_len = 1024
|
|
||||||
|
|
||||||
path = os.path.expanduser("~/huggingface/Qwen3-0.6B/")
|
|
||||||
llm = LLM(path, enforce_eager=False, max_model_len=4096)
|
llm = LLM(path, enforce_eager=False, max_model_len=4096)
|
||||||
|
|
||||||
prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
|
# Warmup
|
||||||
sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)]
|
|
||||||
# uncomment the following line for vllm
|
|
||||||
# prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
|
|
||||||
|
|
||||||
llm.generate(["Benchmark: "], SamplingParams())
|
llm.generate(["Benchmark: "], SamplingParams())
|
||||||
t = time.time()
|
|
||||||
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
print("=" * 60)
|
||||||
t = (time.time() - t)
|
print("Prefill Benchmark")
|
||||||
total_tokens = sum(sp.max_tokens for sp in sampling_params)
|
print("=" * 60)
|
||||||
throughput = total_tokens / t
|
bench_prefill(llm, num_seqs=1, input_len=1024)
|
||||||
print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
bench_prefill(llm, num_seqs=1, input_len=2048)
|
||||||
|
bench_prefill(llm, num_seqs=1, input_len=4095)
|
||||||
|
bench_prefill(llm, num_seqs=16, input_len=1024)
|
||||||
|
bench_prefill(llm, num_seqs=64, input_len=1024)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("Decode Benchmark")
|
||||||
|
print("=" * 60)
|
||||||
|
bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=1024)
|
||||||
|
bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
63
bench_vllm.py
Normal file
63
bench_vllm.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import os
|
||||||
|
os.environ["VLLM_USE_V1"] = "1"
|
||||||
|
import time
|
||||||
|
from random import randint, seed
|
||||||
|
from vllm import LLM, SamplingParams
|
||||||
|
|
||||||
|
|
||||||
|
def bench_decode(llm, num_seqs, max_input_len, max_output_len):
|
||||||
|
"""Benchmark decode performance (original test)"""
|
||||||
|
seed(0)
|
||||||
|
prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)]
|
||||||
|
sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_output_len)) for _ in range(num_seqs)]
|
||||||
|
prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
|
||||||
|
|
||||||
|
t = time.time()
|
||||||
|
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
||||||
|
t = time.time() - t
|
||||||
|
total_output_tokens = sum(sp.max_tokens for sp in sampling_params)
|
||||||
|
throughput = total_output_tokens / t
|
||||||
|
print(f"[Decode] Output: {total_output_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
||||||
|
|
||||||
|
|
||||||
|
def bench_prefill(llm, num_seqs, input_len):
|
||||||
|
"""Benchmark prefill performance"""
|
||||||
|
seed(0)
|
||||||
|
# Fixed length input, minimal output to focus on prefill
|
||||||
|
prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
|
||||||
|
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
|
||||||
|
prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
|
||||||
|
|
||||||
|
t = time.time()
|
||||||
|
llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
||||||
|
t = time.time() - t
|
||||||
|
total_input_tokens = num_seqs * input_len
|
||||||
|
throughput = total_input_tokens / t
|
||||||
|
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
||||||
|
llm = LLM(path, enforce_eager=False, max_model_len=4096, max_num_seqs=128, gpu_memory_utilization=0.9)
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
llm.generate([dict(prompt_token_ids=[0])], SamplingParams())
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("Prefill Benchmark")
|
||||||
|
print("=" * 60)
|
||||||
|
bench_prefill(llm, num_seqs=1, input_len=1024)
|
||||||
|
bench_prefill(llm, num_seqs=1, input_len=2048)
|
||||||
|
bench_prefill(llm, num_seqs=1, input_len=4095)
|
||||||
|
bench_prefill(llm, num_seqs=16, input_len=1024)
|
||||||
|
bench_prefill(llm, num_seqs=64, input_len=1024)
|
||||||
|
|
||||||
|
print("=" * 60)
|
||||||
|
print("Decode Benchmark")
|
||||||
|
print("=" * 60)
|
||||||
|
bench_decode(llm, num_seqs=1, max_input_len=1024, max_output_len=1024)
|
||||||
|
bench_decode(llm, num_seqs=256, max_input_len=1024, max_output_len=1024)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -3,7 +3,7 @@ from torch import nn
|
|||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
from nanovllm.utils.context import get_context
|
from nanovllm.utils.context import get_context
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user