From df99418f7d6ca676550f4372cdc6e1521ce8c33d Mon Sep 17 00:00:00 2001
From: GeeeekExplorer <2651904866@qq.com>
Date: Sun, 31 Aug 2025 19:44:57 +0800
Subject: [PATCH] simplify

---
 example.py                          |  1 -
 nanovllm/engine/block_manager.py    |  1 -
 nanovllm/engine/llm_engine.py       |  2 +-
 nanovllm/engine/model_runner.py     | 13 +++---
 nanovllm/layers/attention.py        |  8 +---
 nanovllm/layers/embed_head.py       |  9 +----
 nanovllm/layers/layernorm.py        |  5 +--
 nanovllm/layers/linear.py           | 61 ++++++++++-------------------
 nanovllm/layers/rotary_embedding.py | 15 ++-----
 nanovllm/layers/sampler.py          |  6 +--
 nanovllm/models/qwen3.py            | 22 ++++-------
 11 files changed, 47 insertions(+), 96 deletions(-)

diff --git a/example.py b/example.py
index 33540f6..82e3917 100644
--- a/example.py
+++ b/example.py
@@ -18,7 +18,6 @@ def main():
             [{"role": "user", "content": prompt}],
             tokenize=False,
             add_generation_prompt=True,
-            enable_thinking=True
         )
         for prompt in prompts
     ]
diff --git a/nanovllm/engine/block_manager.py b/nanovllm/engine/block_manager.py
index 4d674d1..65d725e 100644
--- a/nanovllm/engine/block_manager.py
+++ b/nanovllm/engine/block_manager.py
@@ -26,7 +26,6 @@ class Block:
 class BlockManager:
 
     def __init__(self, num_blocks: int, block_size: int):
-        assert num_blocks > 0
         self.block_size = block_size
         self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
         self.hash_to_block_id: dict[int, int] = dict()
diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py
index 2d43b50..ed4df26 100644
--- a/nanovllm/engine/llm_engine.py
+++ b/nanovllm/engine/llm_engine.py
@@ -86,7 +86,7 @@ class LLMEngine:
                 outputs[seq_id] = token_ids
                 if use_tqdm:
                     pbar.update(1)
-        outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
+        outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
         outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
         if use_tqdm:
             pbar.close()
diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py
index aaa29cf..e9572eb 100644
--- a/nanovllm/engine/model_runner.py
+++ b/nanovllm/engine/model_runner.py
@@ -66,7 +66,7 @@ class ModelRunner:
                 break
 
     def read_shm(self):
-        assert self.world_size > 1 and self.rank
+        assert self.world_size > 1 and self.rank > 0
         self.event.wait()
         n = int.from_bytes(self.shm.buf[0:4], "little")
         method_name, *args = pickle.loads(self.shm.buf[4:n+4])
@@ -74,7 +74,7 @@ class ModelRunner:
         return method_name, args
 
     def write_shm(self, method_name, *args):
-        assert self.world_size > 1 and not self.rank
+        assert self.world_size > 1 and self.rank == 0
        data = pickle.dumps([method_name, *args])
         n = len(data)
         self.shm.buf[0:4] = n.to_bytes(4, "little")
@@ -108,7 +108,7 @@ class ModelRunner:
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
         config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
         assert config.num_kvcache_blocks > 0
-        self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
+        self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
             if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
@@ -141,7 +141,7 @@ class ModelRunner:
             cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
             max_seqlen_q = max(seqlen_q, max_seqlen_q)
             max_seqlen_k = max(seqlen_k, max_seqlen_k)
-            if not seq.block_table:
+            if not seq.block_table:  # warmup
                 continue
             for i in range(seq.num_cached_blocks, seq.num_blocks):
                 start = seq.block_table[i] * self.block_size
@@ -194,12 +194,11 @@ class ModelRunner:
             context = get_context()
             graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
             graph_vars = self.graph_vars
-            for k, v in graph_vars.items():
-                if k != "outputs":
-                    v.zero_()
             graph_vars["input_ids"][:bs] = input_ids
             graph_vars["positions"][:bs] = positions
+            graph_vars["slot_mapping"].fill_(-1)
             graph_vars["slot_mapping"][:bs] = context.slot_mapping
+            graph_vars["context_lens"].zero_()
             graph_vars["context_lens"][:bs] = context.context_lens
             graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
             graph.replay()
diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index d036641..e416139 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -19,11 +19,12 @@ def store_kvcache_kernel(
     D: tl.constexpr,
 ):
     idx = tl.program_id(0)
+    slot = tl.load(slot_mapping_ptr + idx)
+    if slot == -1: return
     key_offsets = idx * key_stride + tl.arange(0, D)
     value_offsets = idx * value_stride + tl.arange(0, D)
     key = tl.load(key_ptr + key_offsets)
     value = tl.load(value_ptr + value_offsets)
-    slot = tl.load(slot_mapping_ptr + idx)
     cache_offsets = slot * D + tl.arange(0, D)
     tl.store(k_cache_ptr + cache_offsets, key)
     tl.store(v_cache_ptr + cache_offsets, value)
@@ -56,10 +57,6 @@ class Attention(nn.Module):
         self.k_cache = self.v_cache = torch.tensor([])
 
     def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
-        o: torch.Tensor
-        q = q.view(-1, self.num_heads, self.head_dim)
-        k = k.view(-1, self.num_kv_heads, self.head_dim)
-        v = v.view(-1, self.num_kv_heads, self.head_dim)
         context = get_context()
         k_cache, v_cache = self.k_cache, self.v_cache
         if k_cache.numel() and v_cache.numel():
@@ -75,5 +72,4 @@ class Attention(nn.Module):
             o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                         cache_seqlens=context.context_lens, block_table=context.block_tables,
                                         softmax_scale=self.scale, causal=True)
-        o = o.view(-1, self.num_heads * self.head_dim)
         return o
diff --git a/nanovllm/layers/embed_head.py b/nanovllm/layers/embed_head.py
index 25241fb..84b3ab5 100644
--- a/nanovllm/layers/embed_head.py
+++ b/nanovllm/layers/embed_head.py
@@ -29,7 +29,6 @@ class VocabParallelEmbedding(nn.Module):
         shard_size = param_data.size(0)
         start_idx = self.tp_rank * shard_size
         loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
-        assert param_data.size() == loaded_weight.size()
         param_data.copy_(loaded_weight)
 
     def forward(self, x: torch.Tensor):
@@ -51,19 +50,15 @@ class ParallelLMHead(VocabParallelEmbedding):
         embedding_dim: int,
         bias: bool = False,
     ):
+        assert not bias
         super().__init__(num_embeddings, embedding_dim)
-        if bias:
-            self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
-            self.bias.weight_loader = self.weight_loader
-        else:
-            self.register_parameter("bias", None)
 
     def forward(self, x: torch.Tensor):
         context = get_context()
         if context.is_prefill:
             last_indices = context.cu_seqlens_q[1:] - 1
             x = x[last_indices].contiguous()
-        logits = F.linear(x, self.weight, self.bias)
+        logits = F.linear(x, self.weight)
         if self.tp_size > 1:
             all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
             dist.gather(logits, all_logits, 0)
diff --git a/nanovllm/layers/layernorm.py b/nanovllm/layers/layernorm.py
index 32dcfa2..71bf419 100755
--- a/nanovllm/layers/layernorm.py
+++ b/nanovllm/layers/layernorm.py
@@ -10,7 +10,6 @@ class RMSNorm(nn.Module):
         eps: float = 1e-6,
     ) -> None:
         super().__init__()
-        self.hidden_size = hidden_size
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(hidden_size))
 
@@ -20,7 +19,7 @@ class RMSNorm(nn.Module):
         x: torch.Tensor,
     ) -> torch.Tensor:
         orig_dtype = x.dtype
-        x = x.to(torch.float32)
+        x = x.float()
         var = x.pow(2).mean(dim=-1, keepdim=True)
         x.mul_(torch.rsqrt(var + self.eps))
         x = x.to(orig_dtype).mul_(self.weight)
@@ -33,7 +32,7 @@ class RMSNorm(nn.Module):
         residual: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         orig_dtype = x.dtype
-        x = x.to(torch.float32).add_(residual.to(torch.float32))
+        x = x.float().add_(residual.float())
         residual = x.to(orig_dtype)
         var = x.pow(2).mean(dim=-1, keepdim=True)
         x.mul_(torch.rsqrt(var + self.eps))
diff --git a/nanovllm/layers/linear.py b/nanovllm/layers/linear.py
index c625a9e..2a9e8d5 100755
--- a/nanovllm/layers/linear.py
+++ b/nanovllm/layers/linear.py
@@ -15,14 +15,20 @@ class LinearBase(nn.Module):
         self,
         input_size: int,
         output_size: int,
+        bias: bool = False,
         tp_dim: int | None = None,
     ):
         super().__init__()
-        self.input_size = input_size
-        self.output_size = output_size
         self.tp_dim = tp_dim
         self.tp_rank = dist.get_rank()
         self.tp_size = dist.get_world_size()
+        self.weight = nn.Parameter(torch.empty(output_size, input_size))
+        self.weight.weight_loader = self.weight_loader
+        if bias:
+            self.bias = nn.Parameter(torch.empty(output_size))
+            self.bias.weight_loader = self.weight_loader
+        else:
+            self.register_parameter("bias", None)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError
@@ -36,14 +42,7 @@ class ReplicatedLinear(LinearBase):
         output_size: int,
         bias: bool = False,
     ):
-        super().__init__(input_size, output_size)
-        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
-        self.weight.weight_loader = self.weight_loader
-        if bias:
-            self.bias = nn.Parameter(torch.empty(self.output_size))
-            self.bias.weight_loader = self.weight_loader
-        else:
-            self.register_parameter("bias", None)
+        super().__init__(input_size, output_size, bias)
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
         param.data.copy_(loaded_weight)
@@ -60,17 +59,8 @@ class ColumnParallelLinear(LinearBase):
         output_size: int,
         bias: bool = False,
     ):
-        super().__init__(input_size, output_size, 0)
-        self.input_size_per_partition = input_size
-        self.output_size_per_partition = divide(output_size, self.tp_size)
-
-        self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
-        self.weight.weight_loader = self.weight_loader
-        if bias:
-            self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
-            self.bias.weight_loader = self.weight_loader
-        else:
-            self.register_parameter("bias", None)
+        tp_size = dist.get_world_size()
+        super().__init__(input_size, divide(output_size, tp_size), bias, 0)
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
         param_data = param.data
@@ -92,7 +82,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         bias: bool = False,
     ):
         self.output_sizes = output_sizes
-        super().__init__(input_size, sum(output_sizes), bias=bias)
+        super().__init__(input_size, sum(output_sizes), bias)
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
         param_data = param.data
@@ -113,15 +103,13 @@ class QKVParallelLinear(ColumnParallelLinear):
         total_num_kv_heads: int | None = None,
         bias: bool = False,
     ):
-        self.head_size = head_size
-        self.total_num_heads = total_num_heads
-        self.total_num_kv_heads = total_num_kv_heads or total_num_heads
         tp_size = dist.get_world_size()
-        self.num_heads = divide(self.total_num_heads, tp_size)
-        self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
-        input_size = hidden_size
-        output_size = (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size
-        super().__init__(input_size, output_size, bias)
+        total_num_kv_heads = total_num_kv_heads or total_num_heads
+        self.head_size = head_size
+        self.num_heads = divide(total_num_heads, tp_size)
+        self.num_kv_heads = divide(total_num_kv_heads, tp_size)
+        output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
+        super().__init__(hidden_size, output_size, bias)
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
         param_data = param.data
@@ -148,17 +136,8 @@ class RowParallelLinear(LinearBase):
         output_size: int,
         bias: bool = False,
     ):
-        super().__init__(input_size, output_size, 1)
-        self.input_size_per_partition = divide(input_size, self.tp_size)
-        self.output_size_per_partition = output_size
-
-        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
-        self.weight.weight_loader = self.weight_loader
-        if bias:
-            self.bias = nn.Parameter(torch.empty(self.output_size))
-            self.bias.weight_loader = self.weight_loader
-        else:
-            self.register_parameter("bias", None)
+        tp_size = dist.get_world_size()
+        super().__init__(divide(input_size, tp_size), output_size, bias, 1)
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
         param_data = param.data
diff --git a/nanovllm/layers/rotary_embedding.py b/nanovllm/layers/rotary_embedding.py
index c473420..998d116 100644
--- a/nanovllm/layers/rotary_embedding.py
+++ b/nanovllm/layers/rotary_embedding.py
@@ -8,9 +8,7 @@ def apply_rotary_emb(
     cos: torch.Tensor,
     sin: torch.Tensor,
 ) -> torch.Tensor:
-    cos = cos.unsqueeze(-2)
-    sin = sin.unsqueeze(-2)
-    x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
+    x1, x2 = torch.chunk(x.float(), 2, dim=-1)
     y1 = x1 * cos - x2 * sin
     y2 = x2 * cos + x1 * sin
     return torch.cat((y1, y2), dim=-1).to(x.dtype)
@@ -33,7 +31,7 @@ class RotaryEmbedding(nn.Module):
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
         sin = freqs.sin()
-        cache = torch.cat((cos, sin), dim=-1)
+        cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
     @torch.compile
@@ -43,15 +41,10 @@ def forward(
         self,
         positions: torch.Tensor,
         query: torch.Tensor,
         key: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        num_tokens = positions.size(0)
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
-        query_shape = query.shape
-        query = query.view(num_tokens, -1, self.head_size)
-        query = apply_rotary_emb(query, cos, sin).view(query_shape)
-        key_shape = key.shape
-        key = key.view(num_tokens, -1, self.head_size)
-        key = apply_rotary_emb(key, cos, sin).view(key_shape)
+        query = apply_rotary_emb(query, cos, sin)
+        key = apply_rotary_emb(key, cos, sin)
         return query, key
diff --git a/nanovllm/layers/sampler.py b/nanovllm/layers/sampler.py
index e4b9816..a5e7ddc 100644
--- a/nanovllm/layers/sampler.py
+++ b/nanovllm/layers/sampler.py
@@ -8,11 +8,9 @@ class Sampler(nn.Module):
         super().__init__()
 
     def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
-        logits = logits.to(torch.float)
+        logits = logits.float()
         greedy_tokens = logits.argmax(dim=-1)
         logits.div_(temperatures.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-        epsilon = 1e-10
-        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + epsilon).argmax(dim=-1)
+        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + 1e-10).argmax(dim=-1)
         return torch.where(temperatures == 0, greedy_tokens, sample_tokens)
diff --git a/nanovllm/models/qwen3.py b/nanovllm/models/qwen3.py
index cca456e..9c042fe 100755
--- a/nanovllm/models/qwen3.py
+++ b/nanovllm/models/qwen3.py
@@ -36,7 +36,7 @@ class Qwen3Attention(nn.Module):
         self.head_dim = head_dim or hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
-        self.scaling = self.head_dim**-0.5
+        self.scaling = self.head_dim ** -0.5
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -73,15 +73,12 @@
     ) -> torch.Tensor:
         qkv = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q_by_head = q.view(-1, self.num_heads, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
-        q = q_by_head.view(q.shape)
-        k_by_head = k.view(-1, self.num_kv_heads, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
-        k = k_by_head.view(k.shape)
+        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
+        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
+        v = v.view(-1, self.num_kv_heads, self.head_dim)
         q, k = self.rotary_emb(positions, q, k)
         o = self.attn(q, k, v)
-        output = self.o_proj(o)
+        output = self.o_proj(o.flatten(1, -1))
         return output
 
 
@@ -147,8 +144,7 @@ class Qwen3DecoderLayer(nn.Module):
         residual: torch.Tensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
+            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
         hidden_states = self.self_attn(positions, hidden_states)
@@ -205,12 +201,10 @@ class Qwen3ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions)
-        return hidden_states
+        return self.model(input_ids, positions)
 
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
-        logits = self.lm_head(hidden_states)
-        return logits
+        return self.lm_head(hidden_states)