diff --git a/nanovllm/engine/block_manager.py b/nanovllm/engine/block_manager.py
index 02bc682..4d674d1 100644
--- a/nanovllm/engine/block_manager.py
+++ b/nanovllm/engine/block_manager.py
@@ -5,14 +5,6 @@ import numpy as np
 from nanovllm.engine.sequence import Sequence
 
 
-def compute_hash(token_ids: list[int], prefix: int = -1):
-    h = xxhash.xxh64()
-    if prefix != -1:
-        h.update(prefix.to_bytes(8, "little"))
-    h.update(np.array(token_ids).tobytes())
-    return h.intdigest()
-
-
 class Block:
 
     def __init__(self, block_id):
@@ -22,7 +14,6 @@ class Block:
         self.token_ids = []
 
     def update(self, hash: int, token_ids: list[int]):
-        assert hash != -1
         self.hash = hash
         self.token_ids = token_ids
 
@@ -42,7 +33,15 @@ class BlockManager:
         self.free_block_ids: deque[int] = deque(range(num_blocks))
         self.used_block_ids: set[int] = set()
 
-    def _allocate_block(self, block_id: int):
+    @classmethod
+    def compute_hash(cls, token_ids: list[int], prefix: int = -1):
+        h = xxhash.xxh64()
+        if prefix != -1:
+            h.update(prefix.to_bytes(8, "little"))
+        h.update(np.array(token_ids).tobytes())
+        return h.intdigest()
+
+    def _allocate_block(self, block_id: int) -> Block:
         block = self.blocks[block_id]
         assert block.ref_count == 0
         block.reset()
@@ -50,12 +49,12 @@ class BlockManager:
         self.used_block_ids.add(block_id)
         return self.blocks[block_id]
 
-    def _deallocate_block(self, block_id: int):
+    def _deallocate_block(self, block_id: int) -> Block:
         assert self.blocks[block_id].ref_count == 0
         self.used_block_ids.remove(block_id)
         self.free_block_ids.append(block_id)
 
-    def can_allocate(self, seq: Sequence):
+    def can_allocate(self, seq: Sequence) -> bool:
         return len(self.free_block_ids) >= seq.num_blocks
 
     def allocate(self, seq: Sequence):
@@ -64,7 +63,7 @@ class BlockManager:
         cache_miss = False
         for i in range(seq.num_blocks):
             token_ids = seq.block(i)
-            h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
+            h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
             block_id = self.hash_to_block_id.get(h, -1)
             if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
                 cache_miss = True
@@ -92,7 +91,7 @@ class BlockManager:
         seq.num_cached_tokens = 0
         seq.block_table.clear()
 
-    def can_append(self, seq: Sequence):
+    def can_append(self, seq: Sequence) -> bool:
         return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
 
     def may_append(self, seq: Sequence):
@@ -107,7 +106,7 @@ class BlockManager:
             assert last_block.hash == -1
             token_ids = seq.block(seq.num_blocks-1)
             prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
-            h = compute_hash(token_ids, prefix)
+            h = self.compute_hash(token_ids, prefix)
             last_block.update(h, token_ids)
             self.hash_to_block_id[h] = last_block.block_id
         else:
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index 6f838a6..3f5636a 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -6,10 +6,9 @@ from multiprocessing.shared_memory import SharedMemory
 
 from nanovllm.config import Config
 from nanovllm.engine.sequence import Sequence
-from nanovllm.utils.context import set_context, get_context, reset_context
-from nanovllm.utils.memory import get_gpu_memory
 from nanovllm.models.qwen3 import Qwen3ForCausalLM
 from nanovllm.layers.sampler import Sampler
+from nanovllm.utils.context import set_context, get_context, reset_context
 from nanovllm.utils.loader import load_model
 
 
@@ -93,11 +92,11 @@ class ModelRunner:
     def allocate_kv_cache(self, gpu_memory_utilization):
         config = self.config
         hf_config = config.hf_config
-        total, used, _ = get_gpu_memory()
-        free = total * gpu_memory_utilization - used
+        free, total = torch.cuda.mem_get_info()
+        used = total - free
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
-        config.num_kvcache_blocks = int(free) // block_bytes
+        config.num_kvcache_blocks = int(total * gpu_memory_utilization - used) // block_bytes
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
@@ -142,7 +141,6 @@ class ModelRunner:
                     end = start + seq.last_block_num_tokens
                 slot_mapping.extend(list(range(start, end)))
         assert len(input_ids) == len(slot_mapping)
-        assert len(input_ids) == cu_seqlens_q[-1]
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]:    # prefix cache
             block_tables = self.prepare_block_tables(seqs)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
diff --git a/nanovllm/engine/scheduler.py b/nanovllm/engine/scheduler.py
index 1d96b1b..5bc19fe 100644
--- a/nanovllm/engine/scheduler.py
+++ b/nanovllm/engine/scheduler.py
@@ -53,10 +53,8 @@ class Scheduler:
                 num_seqs += 1
                 self.block_manager.may_append(seq)
                 scheduled_seqs.append(seq)
-        running = deque(scheduled_seqs)
-        running.extend(self.running)
-        self.running = running
         assert scheduled_seqs
+        self.running.extendleft(reversed(scheduled_seqs))
         return scheduled_seqs, False
 
     def preempt(self, seq: Sequence):
diff --git a/nanovllm/layers/embed_head.py b/nanovllm/layers/embed_head.py
index 6422337..25241fb 100644
--- a/nanovllm/layers/embed_head.py
+++ b/nanovllm/layers/embed_head.py
@@ -21,7 +21,6 @@ class VocabParallelEmbedding(nn.Module):
         self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
         self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
         self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
-        self.embedding_dim = embedding_dim
         self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
         self.weight.weight_loader = self.weight_loader
 
diff --git a/nanovllm/layers/linear.py b/nanovllm/layers/linear.py
index 0ae6eed..f3923cc 100755
--- a/nanovllm/layers/linear.py
+++ b/nanovllm/layers/linear.py
@@ -64,12 +64,6 @@ class ColumnParallelLinear(LinearBase):
         super().__init__(input_size, output_size, 0)
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
-        self.output_partition_sizes = [self.output_size_per_partition]
-        if hasattr(self, "output_sizes"):
-            self.output_partition_sizes = [
-                divide(output_size, self.tp_size)
-                for output_size in self.output_sizes
-            ]
 
         self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
         self.weight.weight_loader = self.weight_loader
@@ -122,23 +116,14 @@ class QKVParallelLinear(ColumnParallelLinear):
         total_num_kv_heads: int | None = None,
         bias: bool = False,
     ):
-        self.hidden_size = hidden_size
         self.head_size = head_size
         self.total_num_heads = total_num_heads
-        if total_num_kv_heads is None:
-            total_num_kv_heads = total_num_heads
-        self.total_num_kv_heads = total_num_kv_heads
+        self.total_num_kv_heads = total_num_kv_heads or total_num_heads
         tp_size = dist.get_world_size()
         self.num_heads = divide(self.total_num_heads, tp_size)
         self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
-        input_size = self.hidden_size
-        output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
-        self.output_sizes = [
-            self.num_heads * self.head_size * tp_size,  # q_proj
-            self.num_kv_heads * self.head_size * tp_size,  # k_proj
-            self.num_kv_heads * self.head_size * tp_size,  # v_proj
-        ]
-
+        input_size = hidden_size
+        output_size = (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size
         super().__init__(input_size, output_size, bias)
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
@@ -170,7 +155,6 @@ class RowParallelLinear(LinearBase):
         super().__init__(input_size, output_size, 1)
         self.input_size_per_partition = divide(input_size, self.tp_size)
         self.output_size_per_partition = output_size
-        self.output_partition_sizes = [output_size]
 
         self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
         self.weight.weight_loader = self.weight_loader
diff --git a/nanovllm/layers/rotary_embedding.py b/nanovllm/layers/rotary_embedding.py
index 1f483e7..c473420 100644
--- a/nanovllm/layers/rotary_embedding.py
+++ b/nanovllm/layers/rotary_embedding.py
@@ -1,5 +1,4 @@
 from functools import lru_cache
-
 import torch
 from torch import nn
 
@@ -28,12 +27,9 @@ class RotaryEmbedding(nn.Module):
     ) -> None:
         super().__init__()
         self.head_size = head_size
-        self.rotary_dim = rotary_dim
         assert rotary_dim == head_size
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
-        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
+        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
         sin = freqs.sin()
@@ -47,8 +43,7 @@ class RotaryEmbedding(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        positions = positions.flatten()
-        num_tokens = positions.shape[0]
+        num_tokens = positions.size(0)
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
         query_shape = query.shape
diff --git a/nanovllm/models/qwen3.py b/nanovllm/models/qwen3.py
index c306be5..03c5d6e 100755
--- a/nanovllm/models/qwen3.py
+++ b/nanovllm/models/qwen3.py
@@ -8,7 +8,7 @@ from nanovllm.layers.attention import Attention
 from nanovllm.layers.layernorm import RMSNorm
 from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
 from nanovllm.layers.rotary_embedding import get_rope
-from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
+from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
 
 
 class Qwen3Attention(nn.Module):
@@ -26,19 +26,17 @@ class Qwen3Attention(nn.Module):
         rope_scaling: tuple | None = None,
     ) -> None:
         super().__init__()
-        self.hidden_size = hidden_size
         tp_size = dist.get_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
         self.total_num_kv_heads = num_kv_heads
         assert self.total_num_kv_heads % tp_size == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
         self.head_dim = head_dim or hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        self.rope_theta = rope_theta
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -57,13 +55,15 @@ class Qwen3Attention(nn.Module):
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position,
-            base=self.rope_theta,
+            base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = Attention(self.num_heads,
-                              self.head_dim,
-                              self.scaling,
-                              num_kv_heads=self.num_kv_heads)
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            self.num_kv_heads,
+        )
         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
 
@@ -122,9 +122,8 @@ class Qwen3DecoderLayer(nn.Module):
         config: Qwen3Config,
     ) -> None:
         super().__init__()
-        self.hidden_size = config.hidden_size
         self.self_attn = Qwen3Attention(
-            hidden_size=self.hidden_size,
+            hidden_size=config.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             max_position=config.max_position_embeddings,
@@ -139,10 +138,8 @@ class Qwen3DecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
         )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -155,10 +152,7 @@ class Qwen3DecoderLayer(nn.Module):
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-        )
+        hidden_states = self.self_attn(positions, hidden_states)
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
@@ -169,9 +163,8 @@ class Qwen3Model(nn.Module):
     def __init__(
         self,
         config: Qwen3Config,
-    ):
+    ) -> None:
         super().__init__()
-        self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
         self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -184,11 +177,7 @@ class Qwen3Model(nn.Module):
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for layer in self.layers:
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                residual,
-            )
+            hidden_states, residual = layer(positions, hidden_states, residual)
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
@@ -205,12 +194,11 @@ class Qwen3ForCausalLM(nn.Module):
     def __init__(
         self,
         config: Qwen3Config
-    ):
+    ) -> None:
         super().__init__()
         self.model = Qwen3Model(config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.tie_word_embeddings = config.tie_word_embeddings
-        if self.tie_word_embeddings:
+        if config.tie_word_embeddings:
             self.lm_head.weight.data = self.model.embed_tokens.weight.data
 
     def forward(
diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py
index a9540b7..2281888 100644
--- a/nanovllm/utils/context.py
+++ b/nanovllm/utils/context.py
@@ -1,4 +1,3 @@
-from contextlib import contextmanager
 from dataclasses import dataclass
 
 import torch
diff --git a/nanovllm/utils/memory.py b/nanovllm/utils/memory.py
deleted file mode 100644
index 83f7729..0000000
--- a/nanovllm/utils/memory.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import os
-import torch
-from pynvml import *
-
-
-def get_gpu_memory():
-    torch.cuda.synchronize()
-    nvmlInit()
-    visible_device = list(map(int, os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(',')))
-    cuda_device_idx = torch.cuda.current_device()
-    cuda_device_idx = visible_device[cuda_device_idx]
-    handle = nvmlDeviceGetHandleByIndex(cuda_device_idx)
-    mem_info = nvmlDeviceGetMemoryInfo(handle)
-    total_memory = mem_info.total
-    used_memory = mem_info.used
-    free_memory = mem_info.free
-    nvmlShutdown()
-    return total_memory, used_memory, free_memory
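
The block_manager.py hunks move compute_hash onto BlockManager as a classmethod but keep the hashing scheme itself: a full block's hash covers the previous block's hash plus the block's own token ids. The standalone sketch below is my own illustration rather than code from the repository (it assumes xxhash and numpy are installed, as block_manager.py already requires) and shows how that chaining makes a block's cache key depend on the entire prefix, so two sequences can only share a cached block when everything before it matches too.

import xxhash
import numpy as np

def compute_hash(token_ids: list[int], prefix: int = -1) -> int:
    # Same scheme as BlockManager.compute_hash in the diff above.
    h = xxhash.xxh64()
    if prefix != -1:
        h.update(prefix.to_bytes(8, "little"))
    h.update(np.array(token_ids).tobytes())
    return h.intdigest()

block_size = 4                      # illustrative block size
tokens = [1, 2, 3, 4, 5, 6, 7, 8]   # two full blocks
h = -1
for i in range(0, len(tokens), block_size):
    block = tokens[i:i + block_size]
    # A full block's hash depends on its own tokens *and* the previous block's
    # hash, so the key for block k implicitly covers blocks 0..k.
    h = compute_hash(block, h)
    print(i // block_size, h)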
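
can_append compares the number of free blocks against a boolean. As I read it together with may_append, the token decoded in the upcoming step lands in slot len(seq) - 1, so a fresh block is needed exactly when that slot is the first slot of a new block, i.e. when len(seq) % block_size == 1; Python then promotes the boolean to 0 or 1 free blocks required. A tiny arithmetic illustration with a made-up block size:

# Illustration only: how many free blocks the next decode step needs,
# assuming the new token is written to slot len(seq) - 1.
block_size = 4

def blocks_needed_for_next_token(seq_len: int) -> int:
    # True -> 1 (the new token starts a fresh block), False -> 0.
    return int(seq_len % block_size == 1)

assert [blocks_needed_for_next_token(n) for n in range(1, 10)] == [1, 0, 0, 0, 1, 0, 0, 0, 1]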
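
The scheduler.py hunk replaces rebuilding a fresh deque with an in-place extendleft(reversed(...)). Both variants put the just-scheduled sequences back at the front of self.running in their original order; the sketch below, with plain integers standing in for Sequence objects, checks the equivalence.

from collections import deque

scheduled_seqs = [1, 2, 3]       # sequences scheduled this step
running = deque([4, 5])          # sequences still waiting in self.running

old_style = deque(scheduled_seqs)
old_style.extend(running)        # removed code: build a new deque, scheduled first

new_style = deque(running)
new_style.extendleft(reversed(scheduled_seqs))  # added code: prepend in place

assert list(old_style) == list(new_style) == [1, 2, 3, 4, 5]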
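
In model_runner.py the pynvml-based helper is dropped in favour of torch.cuda.mem_get_info(), and the block budget becomes int(total * gpu_memory_utilization - used) // block_bytes. The sketch below redoes that arithmetic outside the engine; the layer, head, and block numbers are illustrative placeholders rather than values read from any real model config, and it needs a CUDA device to run.

import torch

block_size = 256                  # tokens per KV-cache block (placeholder)
num_hidden_layers = 28            # placeholder layer count
num_kv_heads = 8                  # already divided by the tensor-parallel size
head_dim = 128                    # placeholder head dimension
dtype_itemsize = torch.bfloat16.itemsize   # 2 bytes per element
gpu_memory_utilization = 0.9

# Bytes per block: K and V (factor 2) for every layer, KV head, and head_dim slot.
block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * dtype_itemsize

free, total = torch.cuda.mem_get_info()    # bytes free / total on the current device
used = total - free
num_kvcache_blocks = int(total * gpu_memory_utilization - used) // block_bytes
print(block_bytes, num_kvcache_blocks)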
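
The QKVParallelLinear hunk swaps the per-rank formula (num_heads + 2 * num_kv_heads) * tp_size * head_size for the global (total_num_heads + 2 * total_num_kv_heads) * head_size. Whenever the head counts divide evenly by the tensor-parallel size, which divide() asserts, the two are equal; a quick check with made-up head counts:

# Equivalence check for the fused QKV output width (illustrative numbers).
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0
    return numerator // denominator

head_size = 128
total_num_heads = 16
total_num_kv_heads = 8

for tp_size in (1, 2, 4, 8):
    num_heads = divide(total_num_heads, tp_size)
    num_kv_heads = divide(total_num_kv_heads, tp_size)
    old = (num_heads + 2 * num_kv_heads) * tp_size * head_size        # removed line
    new = (total_num_heads + 2 * total_num_kv_heads) * head_size      # added line
    assert old == new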