import torch from nanovllm.config import Config from nanovllm.engine.sequence import Sequence from nanovllm.utils.context import set_context, get_context, reset_context from nanovllm.utils.memory import get_gpu_memory from nanovllm.models.qwen3 import Qwen3ForCausalLM from nanovllm.layers.sampler import Sampler class ModelRunner: def __init__(self, config: Config): self.config = config hf_config = config.hf_config self.block_size = config.kvcache_block_size self.enforce_eager = config.enforce_eager default_dtype = torch.get_default_dtype() torch.set_default_dtype(hf_config.torch_dtype) torch.set_default_device("cuda") self.model = Qwen3ForCausalLM(hf_config) self.model.load_weights(config.model) self.sampler = Sampler() self.allocate_kv_cache(config.gpu_memory_utilization) if not self.enforce_eager: self.capture_model() torch.set_default_device("cpu") torch.set_default_dtype(default_dtype) def allocate_kv_cache(self, gpu_memory_utilization): config = self.config hf_config = config.hf_config total, used, _ = get_gpu_memory() free = total * gpu_memory_utilization - used block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * hf_config.num_key_value_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize config.num_kvcache_blocks = int(free * 1e6) // block_bytes self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, hf_config.num_key_value_heads, hf_config.head_dim) layer_id = 0 for module in self.model.modules(): if hasattr(module, "k_cache") and hasattr(module, "v_cache"): module.k_cache = self.kv_cache[0, layer_id] module.v_cache = self.kv_cache[1, layer_id] layer_id += 1 def preare_block_tables(self, seqs: list[Sequence]): max_len = max(len(seq.block_table) for seq in seqs) block_tables = [ seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs ] block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) return block_tables def prepare_prefill(self, seqs: list[Sequence]): input_ids = [] positions = [] cu_seqlens_q = [0] cu_seqlens_k = [0] max_seqlen_q = 0 max_seqlen_k = 0 slot_mapping = [] context_lens = None block_tables = None for seq in seqs: seqlen = len(seq) input_ids.extend(seq[seq.num_cached_tokens:]) positions.extend(list(range(seq.num_cached_tokens, len(seq)))) seqlen_q = seqlen - seq.num_cached_tokens seqlen_k = seqlen cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) max_seqlen_q = max(seqlen_q, max_seqlen_q) max_seqlen_k = max(seqlen_k, max_seqlen_k) for i in range(seq.num_cached_blocks, seq.num_blocks): start = seq.block_table[i] * self.block_size if i != seq.num_blocks - 1: end = start + self.block_size else: end = start + len(seq.last_block()) slot_mapping.extend(list(range(start, end))) assert len(input_ids) == len(slot_mapping) assert len(input_ids) == cu_seqlens_q[-1] if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache context_lens = torch.tensor([len(seq) for seq in seqs], dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) block_tables = self.preare_block_tables(seqs) input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables) return input_ids, positions def prepare_decode(self, seqs: list[Sequence]): input_ids = [] positions = [] slot_mapping = [] context_lens = [] for seq in seqs: input_ids.append(seq.last_token) positions.append(len(seq)) context_lens.append(len(seq)) slot_mapping.append(seq.block_table[-1] * self.block_size + len(seq.last_block())) input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) block_tables = self.preare_block_tables(seqs) set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables) return input_ids, positions def prepare_sample(self, seqs: list[Sequence]): temperatures = [] for seq in seqs: temperatures.append(seq.temperature) temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True) return temperatures @torch.inference_mode() def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill): if is_prefill or self.enforce_eager or input_ids.size(0) > 256: return self.model.compute_logits(self.model(input_ids, positions)) else: bs = input_ids.size(0) context = get_context() self.reset_graph_vars() graph = self.graphs[next(x for x in self.graph_bs if x >= bs)] graph_vars = self.graph_vars graph_vars["input_ids"][:bs] = input_ids graph_vars["positions"][:bs] = positions graph_vars["slot_mapping"][:bs] = context.slot_mapping graph_vars["context_lens"][:bs] = context.context_lens graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables graph.replay() return self.model.compute_logits(graph_vars["outputs"][:bs]) def reset_graph_vars(self): graph_vars = self.graph_vars graph_vars["input_ids"].zero_() graph_vars["positions"].zero_() graph_vars["slot_mapping"].zero_() graph_vars["context_lens"].zero_() graph_vars["block_tables"].zero_() def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]: input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) temperatures = self.prepare_sample(seqs) logits = self.run_model(input_ids, positions, is_prefill) token_ids = self.sampler(logits, temperatures).tolist() reset_context() return token_ids @torch.inference_mode() def capture_model(self): get_rng_state = torch.cuda.get_rng_state set_rng_state = torch.cuda.set_rng_state rng_state = torch.cuda.get_rng_state() torch.cuda.get_rng_state = lambda: rng_state torch.cuda.set_rng_state = lambda _: None config = self.config hf_config = config.hf_config max_bs = min(self.config.max_num_seqs, 256) max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size input_ids = torch.zeros(max_bs, dtype=torch.int64) positions = torch.zeros(max_bs, dtype=torch.int64) slot_mapping = torch.zeros(max_bs, dtype=torch.int32) context_lens = torch.zeros(max_bs, dtype=torch.int32) block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32) outputs = torch.zeros(max_bs, hf_config.hidden_size) self.graph_bs = [1, 2, 4, 8, 16] + list(range(32, max_bs + 1, 32)) self.graphs = {} self.graph_pool = None for bs in reversed(self.graph_bs): graph = torch.cuda.CUDAGraph() set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs]) outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup with torch.cuda.graph(graph, self.graph_pool): outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture if self.graph_pool is None: self.graph_pool = graph.pool() self.graphs[bs] = graph torch.cuda.synchronize() reset_context() self.graph_vars = dict( input_ids=input_ids, positions=positions, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables, outputs=outputs, ) torch.cuda.get_rng_state = get_rng_state torch.cuda.set_rng_state = set_rng_state