# nano-vllm/nanovllm/engine/model_runner.py

import pickle
import torch
import torch.distributed as dist
from multiprocessing.synchronize import Event
from multiprocessing.shared_memory import SharedMemory
from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence
from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import Sampler
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model
from nanovllm.kvcache import create_kvcache_manager, KVCacheManager
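

# ModelRunner executes the model on a single tensor-parallel rank: it loads the
# Qwen3 weights, sizes and allocates the KV cache (optionally with CPU offload),
# captures CUDA graphs for decode, and, on ranks > 0, serves method calls that
# rank 0 broadcasts over shared memory.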
class ModelRunner:

    def __init__(self, config: Config, rank: int, event: Event | list[Event]):
        self.config = config
        hf_config = config.hf_config
        self.block_size = config.kvcache_block_size
        self.enforce_eager = config.enforce_eager
        self.world_size = config.tensor_parallel_size
        self.rank = rank
        self.event = event
        dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank)
        torch.cuda.set_device(rank)
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(hf_config.torch_dtype)
        torch.set_default_device("cuda")
        self.model = Qwen3ForCausalLM(hf_config)
        load_model(self.model, config.model)
        self.sampler = Sampler()
        self.warmup_model()
        self.allocate_kv_cache()
        if not self.enforce_eager:
            self.capture_cudagraph()
        torch.set_default_device("cpu")
        torch.set_default_dtype(default_dtype)
        if self.world_size > 1:
            if rank == 0:
                self.shm = SharedMemory(name="nanovllm", create=True, size=2**20)
                dist.barrier()
            else:
                dist.barrier()
                self.shm = SharedMemory(name="nanovllm")
                self.loop()
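
    # Tensor-parallel control plane: rank 0 pickles (method_name, *args) into the
    # shared-memory buffer (a 4-byte little-endian length prefix followed by the
    # payload) and sets each worker's Event; workers block in loop(), decode the
    # call in read_shm(), and execute it via call().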
    def exit(self):
        if self.world_size > 1:
            self.shm.close()
            dist.barrier()
            if self.rank == 0:
                self.shm.unlink()
        if not self.enforce_eager:
            del self.graphs, self.graph_pool
        torch.cuda.synchronize()
        dist.destroy_process_group()

    def loop(self):
        while True:
            method_name, args = self.read_shm()
            self.call(method_name, *args)
            if method_name == "exit":
                break

    def read_shm(self):
        assert self.world_size > 1 and self.rank > 0
        self.event.wait()
        n = int.from_bytes(self.shm.buf[0:4], "little")
        method_name, *args = pickle.loads(self.shm.buf[4:n+4])
        self.event.clear()
        return method_name, args

    def write_shm(self, method_name, *args):
        assert self.world_size > 1 and self.rank == 0
        data = pickle.dumps([method_name, *args])
        n = len(data)
        self.shm.buf[0:4] = n.to_bytes(4, "little")
        self.shm.buf[4:n+4] = data
        for event in self.event:
            event.set()

    def call(self, method_name, *args):
        if self.world_size > 1 and self.rank == 0:
            self.write_shm(method_name, *args)
        method = getattr(self, method_name, None)
        return method(*args)

    def warmup_model(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len
        num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs)
        seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)]
        self.run(seqs, True)
        torch.cuda.empty_cache()
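
    # KV cache sizing: a block stores `block_size` tokens of K and V for every
    # layer, so
    #   block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * dtype_itemsize
    # and the GPU block budget is the memory still available under
    #   total * gpu_memory_utilization - used - peak + current
    # as measured right after warmup_model() has exercised a worst-case batch.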
    def allocate_kv_cache(self):
        config = self.config
        hf_config = config.hf_config
        free, total = torch.cuda.mem_get_info()
        used = total - free
        peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
        current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
        num_kv_heads = hf_config.num_key_value_heads // self.world_size
        head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
        # Calculate the GPU block count
        num_gpu_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
        assert num_gpu_blocks > 0
        if config.enable_cpu_offload:
            # Calculate CPU blocks based on cpu_memory_gb
            cpu_bytes = int(config.cpu_memory_gb * 1024**3)
            num_cpu_blocks = cpu_bytes // block_bytes
            config.num_gpu_kvcache_blocks = num_gpu_blocks
            config.num_cpu_kvcache_blocks = num_cpu_blocks
            # For backward compatibility
            config.num_kvcache_blocks = num_gpu_blocks + num_cpu_blocks
        else:
            config.num_kvcache_blocks = num_gpu_blocks
            config.num_gpu_kvcache_blocks = num_gpu_blocks
            config.num_cpu_kvcache_blocks = 0
        # Create the KV cache manager through the factory
        self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
        # Allocate the cache through the manager
        self.kvcache_manager.allocate_cache(
            num_layers=hf_config.num_hidden_layers,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            dtype=hf_config.torch_dtype,
        )
        # Bind layer caches to attention modules and set layer_id
        layer_id = 0
        for module in self.model.modules():
            if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
                k_cache, v_cache = self.kvcache_manager.get_layer_cache(layer_id)
                module.k_cache = k_cache
                module.v_cache = v_cache
                # Set layer_id for chunked prefill support
                if hasattr(module, "layer_id"):
                    module.layer_id = layer_id
                layer_id += 1

    def prepare_block_tables(self, seqs: list[Sequence]):
        max_len = max(len(seq.block_table) for seq in seqs)
        block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs]
        block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        return block_tables
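
    # seq.block_table holds logical block IDs; the KV cache manager resolves them
    # to physical GPU slots. The slot mapping then gives each token its flat cache
    # index: gpu_block_id * block_size + offset_within_block.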
    def prepare_prefill(self, seqs: list[Sequence], chunk_info: list[tuple] | None = None):
        """
        Prepare inputs for prefill.

        Args:
            seqs: List of sequences to prefill.
            chunk_info: Optional chunked-prefill info from get_gpu_block_tables_partial().
                If provided, only the blocks in the chunk are processed.
                Format: [(gpu_block_ids, start_block_idx, end_block_idx), ...]
        """
        # Check if any sequence has blocks (not warmup)
        has_blocks = any(seq.block_table for seq in seqs)
        gpu_block_tables = None
        if has_blocks and hasattr(self, 'kvcache_manager'):
            if chunk_info is None:
                # Standard prefill: try to get all blocks.
                # This may fail if the GPU doesn't have enough capacity.
                self.kvcache_manager.prepare_for_attention(seqs, is_prefill=True)
                gpu_block_tables = self.kvcache_manager.get_gpu_block_tables(seqs)
            else:
                # Chunked prefill: use the provided chunk info
                gpu_block_tables = [info[0] for info in chunk_info]
        input_ids = []
        positions = []
        cu_seqlens_q = [0]
        cu_seqlens_k = [0]
        max_seqlen_q = 0
        max_seqlen_k = 0
        slot_mapping = []
        block_tables = None
        for seq_idx, seq in enumerate(seqs):
            if chunk_info is not None:
                # Chunked prefill: only process blocks in the chunk
                gpu_blocks, start_block_idx, end_block_idx = chunk_info[seq_idx]
                if not gpu_blocks:
                    continue
                # Calculate the token range for this chunk
                start_token = start_block_idx * self.block_size
                end_token = min(end_block_idx * self.block_size, len(seq))
                if end_block_idx == seq.num_blocks:
                    # The last chunk includes the partial last block
                    end_token = len(seq)
                # Input tokens for this chunk
                chunk_tokens = seq[start_token:end_token]
                input_ids.extend(chunk_tokens)
                positions.extend(list(range(start_token, end_token)))
                seqlen_q = end_token - start_token
                seqlen_k = end_token    # Context includes all tokens up to this point
                cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
                cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
                max_seqlen_q = max(seqlen_q, max_seqlen_q)
                max_seqlen_k = max(seqlen_k, max_seqlen_k)
                # Slot mapping for blocks in this chunk
                for i, gpu_block_id in enumerate(gpu_blocks):
                    block_idx = start_block_idx + i
                    start = gpu_block_id * self.block_size
                    if block_idx != seq.num_blocks - 1:
                        end = start + self.block_size
                    else:
                        end = start + seq.last_block_num_tokens
                    slot_mapping.extend(list(range(start, end)))
            else:
                # Standard prefill
                seqlen = len(seq)
                input_ids.extend(seq[seq.num_cached_tokens:])
                positions.extend(list(range(seq.num_cached_tokens, seqlen)))
                seqlen_q = seqlen - seq.num_cached_tokens
                seqlen_k = seqlen
                cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q)
                cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
                max_seqlen_q = max(seqlen_q, max_seqlen_q)
                max_seqlen_k = max(seqlen_k, max_seqlen_k)
                if not seq.block_table:    # warmup
                    continue
                # Use GPU physical block IDs for the slot mapping
                gpu_blocks = gpu_block_tables[seq_idx]
                for i in range(seq.num_cached_blocks, seq.num_blocks):
                    start = gpu_blocks[i] * self.block_size
                    if i != seq.num_blocks - 1:
                        end = start + self.block_size
                    else:
                        end = start + seq.last_block_num_tokens
                    slot_mapping.extend(list(range(start, end)))
        if cu_seqlens_k[-1] > cu_seqlens_q[-1] and gpu_block_tables:    # prefix cache
            block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
        return input_ids, positions
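
    # Decode appends exactly one token per sequence, so the slot mapping is a
    # single index into the last block:
    #   gpu_blocks[-1] * block_size + seq.last_block_num_tokens - 1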
    def prepare_decode(self, seqs: list[Sequence]):
        # Prepare the KV cache (updates gather_indices for the hybrid manager)
        if hasattr(self, 'kvcache_manager'):
            self.kvcache_manager.prepare_for_attention(seqs, is_prefill=False)
            # Get GPU physical block tables
            gpu_block_tables = self.kvcache_manager.get_gpu_block_tables(seqs)
        else:
            gpu_block_tables = [list(seq.block_table) for seq in seqs]
        input_ids = []
        positions = []
        slot_mapping = []
        context_lens = []
        for seq_idx, seq in enumerate(seqs):
            input_ids.append(seq.last_token)
            positions.append(len(seq) - 1)
            context_lens.append(len(seq))
            # Use the GPU physical block ID for the slot mapping
            gpu_blocks = gpu_block_tables[seq_idx]
            slot_mapping.append(gpu_blocks[-1] * self.block_size + seq.last_block_num_tokens - 1)
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        # Use GPU physical block tables for attention
        block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
        set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
        return input_ids, positions

    def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]):
        """Prepare the block tables tensor from GPU physical block IDs."""
        max_len = max(len(bt) for bt in gpu_block_tables)
        padded = [bt + [-1] * (max_len - len(bt)) for bt in gpu_block_tables]
        return torch.tensor(padded, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)

    def prepare_sample(self, seqs: list[Sequence]):
        temperatures = []
        for seq in seqs:
            temperatures.append(seq.temperature)
        temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True)
        return temperatures
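
    # Decode batches of up to 512 sequences replay a pre-captured CUDA graph,
    # padding the batch up to the next captured size; prefill, eager mode, and
    # larger batches take the normal forward path.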
    @torch.inference_mode()
    def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool):
        if is_prefill or self.enforce_eager or input_ids.size(0) > 512:
            return self.model.compute_logits(self.model(input_ids, positions))
        else:
            bs = input_ids.size(0)
            context = get_context()
            graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
            graph_vars = self.graph_vars
            graph_vars["input_ids"][:bs] = input_ids
            graph_vars["positions"][:bs] = positions
            graph_vars["slot_mapping"].fill_(-1)
            graph_vars["slot_mapping"][:bs] = context.slot_mapping
            graph_vars["context_lens"].zero_()
            graph_vars["context_lens"][:bs] = context.context_lens
            graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
            graph.replay()
            return self.model.compute_logits(graph_vars["outputs"][:bs])

    def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
        # Check if chunked prefill is needed
        if is_prefill and hasattr(self, 'kvcache_manager'):
            needs_chunked = any(
                hasattr(self.kvcache_manager, 'needs_chunked_prefill') and
                self.kvcache_manager.needs_chunked_prefill(seq)
                for seq in seqs if seq.block_table
            )
            if needs_chunked:
                return self.run_chunked_prefill(seqs)
        # Check if chunked decode is needed
        if not is_prefill and hasattr(self, 'kvcache_manager'):
            needs_chunked = any(
                hasattr(self.kvcache_manager, 'needs_chunked_decode') and
                self.kvcache_manager.needs_chunked_decode(seq)
                for seq in seqs if seq.block_table
            )
            if needs_chunked:
                return self.run_chunked_decode(seqs)
        input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        logits = self.run_model(input_ids, positions, is_prefill)
        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
        reset_context()
        return token_ids

    def run_chunked_prefill(self, seqs: list[Sequence]) -> list[int]:
        """
        Run prefill in chunks when sequences exceed GPU capacity.

        For each chunk:
        1. Process tokens through the model forward pass.
        2. At each attention layer:
           - Load previous KV from CPU (handled by the attention layer).
           - Compute attention with online softmax merging.
           - Store the current KV to the GPU cache.
        3. After the chunk completes, offload its KV to CPU.
        4. Load the next chunk's blocks onto the GPU.
        """
        import sys
        # Currently only a single sequence is supported for chunked prefill
        assert len(seqs) == 1, "Chunked prefill only supports a single sequence"
        seq = seqs[0]
        total_blocks = seq.num_blocks
        print(f"[Chunked Prefill] Starting: {total_blocks} total blocks, "
              f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)
        chunk_num = 0
        logits = None
        while True:
            # Get chunk info (which blocks are on GPU and not yet prefilled)
            chunk_info = self.kvcache_manager.get_gpu_block_tables_partial(seqs)
            gpu_blocks, start_block_idx, end_block_idx = chunk_info[0]
            if not gpu_blocks:
                # No more blocks to process
                break
            chunk_num += 1
            chunk_tokens = (end_block_idx - start_block_idx) * self.block_size
            if end_block_idx == seq.num_blocks:
                # The last block may be partial
                chunk_tokens = len(seq) - start_block_idx * self.block_size
            print(f"[Chunked Prefill] Chunk {chunk_num}: blocks {start_block_idx}-{end_block_idx - 1}, "
                  f"~{chunk_tokens} tokens", file=sys.stderr)
            # Prepare inputs for this chunk
            input_ids, positions = self._prepare_chunked_prefill(seq, gpu_blocks, start_block_idx, end_block_idx)
            if input_ids.numel() == 0:
                print("[Chunked Prefill] No input tokens, breaking", file=sys.stderr)
                break
            print(f"[Chunked Prefill] Running model with {input_ids.numel()} tokens...", file=sys.stderr)
            # Run the model forward pass
            logits = self.run_model(input_ids, positions, is_prefill=True)
            reset_context()
            print("[Chunked Prefill] Model forward complete", file=sys.stderr)
            # Mark the current chunk as prefilled and offload its KV to CPU
            self.kvcache_manager.complete_prefill_chunk(seq)
            # Check if more chunks are needed
            if not self.kvcache_manager.needs_chunked_prefill(seq):
                print("[Chunked Prefill] All chunks done, sampling", file=sys.stderr)
                break
            print("[Chunked Prefill] Chunk transfer complete, loading next...", file=sys.stderr)
        # Sample from the last chunk's logits
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        if logits is not None:
            token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
        else:
            token_ids = [0] if self.rank == 0 else None
        return token_ids

    def run_chunked_decode(self, seqs: list[Sequence]) -> list[int]:
        """
        Run decode with chunked attention when the sequence exceeds GPU capacity.

        Decode needs attention over ALL previous tokens. With CPU offload, KV
        chunks are loaded and attention is computed incrementally.
        """
        import sys
        # Currently only a single sequence is supported for chunked decode
        assert len(seqs) == 1, "Chunked decode only supports a single sequence"
        seq = seqs[0]
        total_blocks = len(seq.block_table)
        print(f"[Chunked Decode] Sequence has {total_blocks} blocks, "
              f"GPU slots: {self.kvcache_manager.num_gpu_slots}", file=sys.stderr)
        # Prepare inputs
        input_ids = torch.tensor([seq.last_token], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor([len(seq) - 1], dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        # Compute the slot mapping for the new token: use the last block's GPU slot
        # if it is resident on GPU; otherwise handle the CPU-resident case below.
        last_logical_id = seq.block_table[-1]
        last_block = self.kvcache_manager.logical_blocks[last_logical_id]
        if last_block.location.name == "GPU":
            slot = last_block.gpu_slot * self.block_size + seq.last_block_num_tokens - 1
        else:
            # The last block is on CPU, so it would have to be brought to GPU to
            # write the new token. As a special case, fall back to a fixed temporary
            # slot (this might conflict, but decode writes only one token, so it
            # should be ok).
            print("[Chunked Decode] Warning: last block on CPU, using temp slot", file=sys.stderr)
            slot = 0  # Use the first slot temporarily
        slot_mapping = torch.tensor([slot], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        context_len = torch.tensor([len(seq)], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        # Set up context for chunked decode
        set_context(
            is_prefill=False,  # Decode mode
            slot_mapping=slot_mapping,
            context_lens=context_len,
            is_chunked_prefill=True,  # Use chunked attention
            offload_engine=self.kvcache_manager,
            chunked_seq=seq,
        )
        # Run the model forward pass
        logits = self.run_model(input_ids, positions, is_prefill=False)
        reset_context()
        # Sample
        temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
        token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None
        return token_ids

    def _prepare_chunked_prefill(
        self,
        seq: Sequence,
        gpu_blocks: list[int],
        start_block_idx: int,
        end_block_idx: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Prepare inputs for a single chunk in chunked prefill.

        Sets up the context with is_chunked_prefill=True so attention layers
        know to load previous KV from CPU.
        """
        # Calculate the token range for this chunk
        start_token = start_block_idx * self.block_size
        end_token = min(end_block_idx * self.block_size, len(seq))
        # Input tokens for this chunk
        input_ids = seq[start_token:end_token]
        positions = list(range(start_token, end_token))
        # Slot mapping for storing the KV cache
        slot_mapping = []
        for i, gpu_block_id in enumerate(gpu_blocks):
            block_idx = start_block_idx + i
            start = gpu_block_id * self.block_size
            if block_idx != seq.num_blocks - 1:
                end = start + self.block_size
            else:
                end = start + seq.last_block_num_tokens
            slot_mapping.extend(list(range(start, end)))
        # Trim slot_mapping to match the actual token count
        actual_tokens = end_token - start_token
        slot_mapping = slot_mapping[:actual_tokens]
        # Convert to tensors
        input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
        slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        # Set up context for chunked prefill
        seqlen = actual_tokens
        cu_seqlens_q = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        cu_seqlens_k = torch.tensor([0, seqlen], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
        set_context(
            is_prefill=True,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=seqlen,
            max_seqlen_k=seqlen,
            slot_mapping=slot_mapping,
            is_chunked_prefill=True,
            offload_engine=self.kvcache_manager,  # Pass the manager for loading previous KV
            chunked_seq=seq,  # Pass the sequence for loading previous KV
        )
        return input_ids, positions
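
    # CUDA graph capture: one graph per decode batch size in [1, 2, 4, 8, 16, 32, ...],
    # captured from largest to smallest so all graphs share one memory pool. At
    # decode time run_model() copies inputs into these persistent buffers and
    # replays the matching graph.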
    @torch.inference_mode()
    def capture_cudagraph(self):
        config = self.config
        hf_config = config.hf_config
        max_bs = min(self.config.max_num_seqs, 512)
        max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size
        input_ids = torch.zeros(max_bs, dtype=torch.int64)
        positions = torch.zeros(max_bs, dtype=torch.int64)
        slot_mapping = torch.zeros(max_bs, dtype=torch.int32)
        context_lens = torch.zeros(max_bs, dtype=torch.int32)
        block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32)
        outputs = torch.zeros(max_bs, hf_config.hidden_size)
        self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16))
        self.graphs = {}
        self.graph_pool = None
        for bs in reversed(self.graph_bs):
            graph = torch.cuda.CUDAGraph()
            set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
            outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # warmup
            with torch.cuda.graph(graph, self.graph_pool):
                outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # capture
            if self.graph_pool is None:
                self.graph_pool = graph.pool()
            self.graphs[bs] = graph
            torch.cuda.synchronize()
            reset_context()
        self.graph_vars = dict(
            input_ids=input_ids,
            positions=positions,
            slot_mapping=slot_mapping,
            context_lens=context_lens,
            block_tables=block_tables,
            outputs=outputs,
        )