This commit is contained in:
GeeeekExplorer
2025-06-21 17:04:53 +08:00
parent ad4e95fbdc
commit cde3fc22c2
9 changed files with 42 additions and 100 deletions

View File

@@ -5,14 +5,6 @@ import numpy as np
from nanovllm.engine.sequence import Sequence from nanovllm.engine.sequence import Sequence
def compute_hash(token_ids: list[int], prefix: int = -1):
h = xxhash.xxh64()
if prefix != -1:
h.update(prefix.to_bytes(8, "little"))
h.update(np.array(token_ids).tobytes())
return h.intdigest()
class Block: class Block:
def __init__(self, block_id): def __init__(self, block_id):
@@ -22,7 +14,6 @@ class Block:
self.token_ids = [] self.token_ids = []
def update(self, hash: int, token_ids: list[int]): def update(self, hash: int, token_ids: list[int]):
assert hash != -1
self.hash = hash self.hash = hash
self.token_ids = token_ids self.token_ids = token_ids
@@ -42,7 +33,15 @@ class BlockManager:
self.free_block_ids: deque[int] = deque(range(num_blocks)) self.free_block_ids: deque[int] = deque(range(num_blocks))
self.used_block_ids: set[int] = set() self.used_block_ids: set[int] = set()
def _allocate_block(self, block_id: int): @classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
h = xxhash.xxh64()
if prefix != -1:
h.update(prefix.to_bytes(8, "little"))
h.update(np.array(token_ids).tobytes())
return h.intdigest()
def _allocate_block(self, block_id: int) -> Block:
block = self.blocks[block_id] block = self.blocks[block_id]
assert block.ref_count == 0 assert block.ref_count == 0
block.reset() block.reset()
@@ -50,12 +49,12 @@ class BlockManager:
self.used_block_ids.add(block_id) self.used_block_ids.add(block_id)
return self.blocks[block_id] return self.blocks[block_id]
def _deallocate_block(self, block_id: int): def _deallocate_block(self, block_id: int) -> Block:
assert self.blocks[block_id].ref_count == 0 assert self.blocks[block_id].ref_count == 0
self.used_block_ids.remove(block_id) self.used_block_ids.remove(block_id)
self.free_block_ids.append(block_id) self.free_block_ids.append(block_id)
def can_allocate(self, seq: Sequence): def can_allocate(self, seq: Sequence) -> bool:
return len(self.free_block_ids) >= seq.num_blocks return len(self.free_block_ids) >= seq.num_blocks
def allocate(self, seq: Sequence): def allocate(self, seq: Sequence):
@@ -64,7 +63,7 @@ class BlockManager:
cache_miss = False cache_miss = False
for i in range(seq.num_blocks): for i in range(seq.num_blocks):
token_ids = seq.block(i) token_ids = seq.block(i)
h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1 h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
block_id = self.hash_to_block_id.get(h, -1) block_id = self.hash_to_block_id.get(h, -1)
if block_id == -1 or self.blocks[block_id].token_ids != token_ids: if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
cache_miss = True cache_miss = True
@@ -92,7 +91,7 @@ class BlockManager:
seq.num_cached_tokens = 0 seq.num_cached_tokens = 0
seq.block_table.clear() seq.block_table.clear()
def can_append(self, seq: Sequence): def can_append(self, seq: Sequence) -> bool:
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1) return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
def may_append(self, seq: Sequence): def may_append(self, seq: Sequence):
@@ -107,7 +106,7 @@ class BlockManager:
assert last_block.hash == -1 assert last_block.hash == -1
token_ids = seq.block(seq.num_blocks-1) token_ids = seq.block(seq.num_blocks-1)
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1 prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
h = compute_hash(token_ids, prefix) h = self.compute_hash(token_ids, prefix)
last_block.update(h, token_ids) last_block.update(h, token_ids)
self.hash_to_block_id[h] = last_block.block_id self.hash_to_block_id[h] = last_block.block_id
else: else:

View File

@@ -6,10 +6,9 @@ from multiprocessing.shared_memory import SharedMemory
from nanovllm.config import Config from nanovllm.config import Config
from nanovllm.engine.sequence import Sequence from nanovllm.engine.sequence import Sequence
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.memory import get_gpu_memory
from nanovllm.models.qwen3 import Qwen3ForCausalLM from nanovllm.models.qwen3 import Qwen3ForCausalLM
from nanovllm.layers.sampler import Sampler from nanovllm.layers.sampler import Sampler
from nanovllm.utils.context import set_context, get_context, reset_context
from nanovllm.utils.loader import load_model from nanovllm.utils.loader import load_model
@@ -93,11 +92,11 @@ class ModelRunner:
def allocate_kv_cache(self, gpu_memory_utilization): def allocate_kv_cache(self, gpu_memory_utilization):
config = self.config config = self.config
hf_config = config.hf_config hf_config = config.hf_config
total, used, _ = get_gpu_memory() free, total = torch.cuda.mem_get_info()
free = total * gpu_memory_utilization - used used = total - free
num_kv_heads = hf_config.num_key_value_heads // self.world_size num_kv_heads = hf_config.num_key_value_heads // self.world_size
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
config.num_kvcache_blocks = int(free) // block_bytes config.num_kvcache_blocks = int(total * gpu_memory_utilization - used) // block_bytes
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim) self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
layer_id = 0 layer_id = 0
for module in self.model.modules(): for module in self.model.modules():
@@ -142,7 +141,6 @@ class ModelRunner:
end = start + seq.last_block_num_tokens end = start + seq.last_block_num_tokens
slot_mapping.extend(list(range(start, end))) slot_mapping.extend(list(range(start, end)))
assert len(input_ids) == len(slot_mapping) assert len(input_ids) == len(slot_mapping)
assert len(input_ids) == cu_seqlens_q[-1]
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
block_tables = self.prepare_block_tables(seqs) block_tables = self.prepare_block_tables(seqs)
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)

View File

@@ -53,10 +53,8 @@ class Scheduler:
num_seqs += 1 num_seqs += 1
self.block_manager.may_append(seq) self.block_manager.may_append(seq)
scheduled_seqs.append(seq) scheduled_seqs.append(seq)
running = deque(scheduled_seqs)
running.extend(self.running)
self.running = running
assert scheduled_seqs assert scheduled_seqs
self.running.extendleft(reversed(scheduled_seqs))
return scheduled_seqs, False return scheduled_seqs, False
def preempt(self, seq: Sequence): def preempt(self, seq: Sequence):

View File

@@ -21,7 +21,6 @@ class VocabParallelEmbedding(nn.Module):
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
self.embedding_dim = embedding_dim
self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim)) self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
self.weight.weight_loader = self.weight_loader self.weight.weight_loader = self.weight_loader

View File

@@ -64,12 +64,6 @@ class ColumnParallelLinear(LinearBase):
super().__init__(input_size, output_size, 0) super().__init__(input_size, output_size, 0)
self.input_size_per_partition = input_size self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size) self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size)) self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
self.weight.weight_loader = self.weight_loader self.weight.weight_loader = self.weight_loader
@@ -122,23 +116,14 @@ class QKVParallelLinear(ColumnParallelLinear):
total_num_kv_heads: int | None = None, total_num_kv_heads: int | None = None,
bias: bool = False, bias: bool = False,
): ):
self.hidden_size = hidden_size
self.head_size = head_size self.head_size = head_size
self.total_num_heads = total_num_heads self.total_num_heads = total_num_heads
if total_num_kv_heads is None: self.total_num_kv_heads = total_num_kv_heads or total_num_heads
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
tp_size = dist.get_world_size() tp_size = dist.get_world_size()
self.num_heads = divide(self.total_num_heads, tp_size) self.num_heads = divide(self.total_num_heads, tp_size)
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
input_size = self.hidden_size input_size = hidden_size
output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size output_size = (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size
self.output_sizes = [
self.num_heads * self.head_size * tp_size, # q_proj
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
]
super().__init__(input_size, output_size, bias) super().__init__(input_size, output_size, bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str): def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
@@ -170,7 +155,6 @@ class RowParallelLinear(LinearBase):
super().__init__(input_size, output_size, 1) super().__init__(input_size, output_size, 1)
self.input_size_per_partition = divide(input_size, self.tp_size) self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition)) self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
self.weight.weight_loader = self.weight_loader self.weight.weight_loader = self.weight_loader

View File

@@ -1,5 +1,4 @@
from functools import lru_cache from functools import lru_cache
import torch import torch
from torch import nn from torch import nn
@@ -28,12 +27,9 @@ class RotaryEmbedding(nn.Module):
) -> None: ) -> None:
super().__init__() super().__init__()
self.head_size = head_size self.head_size = head_size
self.rotary_dim = rotary_dim
assert rotary_dim == head_size assert rotary_dim == head_size
self.max_position_embeddings = max_position_embeddings inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
self.base = base t = torch.arange(max_position_embeddings, dtype=torch.float)
inv_freq = 1.0 / (base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq) freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() cos = freqs.cos()
sin = freqs.sin() sin = freqs.sin()
@@ -47,8 +43,7 @@ class RotaryEmbedding(nn.Module):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
positions = positions.flatten() num_tokens = positions.size(0)
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache[positions] cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1) cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape query_shape = query.shape

View File

@@ -8,7 +8,7 @@ from nanovllm.layers.attention import Attention
from nanovllm.layers.layernorm import RMSNorm from nanovllm.layers.layernorm import RMSNorm
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
from nanovllm.layers.rotary_embedding import get_rope from nanovllm.layers.rotary_embedding import get_rope
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
class Qwen3Attention(nn.Module): class Qwen3Attention(nn.Module):
@@ -26,19 +26,17 @@ class Qwen3Attention(nn.Module):
rope_scaling: tuple | None = None, rope_scaling: tuple | None = None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = hidden_size
tp_size = dist.get_world_size() tp_size = dist.get_world_size()
self.total_num_heads = num_heads self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0 assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads self.total_num_kv_heads = num_kv_heads
assert self.total_num_kv_heads % tp_size == 0 assert self.total_num_kv_heads % tp_size == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.num_kv_heads = self.total_num_kv_heads // tp_size
self.head_dim = head_dim or hidden_size // self.total_num_heads self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
hidden_size, hidden_size,
@@ -57,13 +55,15 @@ class Qwen3Attention(nn.Module):
self.head_dim, self.head_dim,
rotary_dim=self.head_dim, rotary_dim=self.head_dim,
max_position=max_position, max_position=max_position,
base=self.rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
) )
self.attn = Attention(self.num_heads, self.attn = Attention(
self.head_dim, self.num_heads,
self.scaling, self.head_dim,
num_kv_heads=self.num_kv_heads) self.scaling,
self.num_kv_heads,
)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
@@ -122,9 +122,8 @@ class Qwen3DecoderLayer(nn.Module):
config: Qwen3Config, config: Qwen3Config,
) -> None: ) -> None:
super().__init__() super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Qwen3Attention( self.self_attn = Qwen3Attention(
hidden_size=self.hidden_size, hidden_size=config.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
@@ -139,10 +138,8 @@ class Qwen3DecoderLayer(nn.Module):
intermediate_size=config.intermediate_size, intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act, hidden_act=config.hidden_act,
) )
self.input_layernorm = RMSNorm(config.hidden_size, self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward( def forward(
self, self,
@@ -155,10 +152,7 @@ class Qwen3DecoderLayer(nn.Module):
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
else: else:
hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn( hidden_states = self.self_attn(positions, hidden_states)
positions=positions,
hidden_states=hidden_states,
)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states) hidden_states = self.mlp(hidden_states)
return hidden_states, residual return hidden_states, residual
@@ -169,9 +163,8 @@ class Qwen3Model(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen3Config, config: Qwen3Config,
): ) -> None:
super().__init__() super().__init__()
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -184,11 +177,7 @@ class Qwen3Model(nn.Module):
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
for layer in self.layers: for layer in self.layers:
hidden_states, residual = layer( hidden_states, residual = layer(positions, hidden_states, residual)
positions,
hidden_states,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
@@ -205,12 +194,11 @@ class Qwen3ForCausalLM(nn.Module):
def __init__( def __init__(
self, self,
config: Qwen3Config config: Qwen3Config
): ) -> None:
super().__init__() super().__init__()
self.model = Qwen3Model(config) self.model = Qwen3Model(config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.tie_word_embeddings = config.tie_word_embeddings if config.tie_word_embeddings:
if self.tie_word_embeddings:
self.lm_head.weight.data = self.model.embed_tokens.weight.data self.lm_head.weight.data = self.model.embed_tokens.weight.data
def forward( def forward(

View File

@@ -1,4 +1,3 @@
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
import torch import torch

View File

@@ -1,18 +0,0 @@
import os
import torch
from pynvml import *
def get_gpu_memory():
torch.cuda.synchronize()
nvmlInit()
visible_device = list(map(int, os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(',')))
cuda_device_idx = torch.cuda.current_device()
cuda_device_idx = visible_device[cuda_device_idx]
handle = nvmlDeviceGetHandleByIndex(cuda_device_idx)
mem_info = nvmlDeviceGetMemoryInfo(handle)
total_memory = mem_info.total
used_memory = mem_info.used
free_memory = mem_info.free
nvmlShutdown()
return total_memory, used_memory, free_memory