simplify
@@ -5,14 +5,6 @@ import numpy as np
 from nanovllm.engine.sequence import Sequence
 
 
-def compute_hash(token_ids: list[int], prefix: int = -1):
-    h = xxhash.xxh64()
-    if prefix != -1:
-        h.update(prefix.to_bytes(8, "little"))
-    h.update(np.array(token_ids).tobytes())
-    return h.intdigest()
-
-
 class Block:
 
     def __init__(self, block_id):
@@ -22,7 +14,6 @@ class Block:
         self.token_ids = []
 
     def update(self, hash: int, token_ids: list[int]):
-        assert hash != -1
         self.hash = hash
         self.token_ids = token_ids
 
@@ -42,7 +33,15 @@ class BlockManager:
         self.free_block_ids: deque[int] = deque(range(num_blocks))
         self.used_block_ids: set[int] = set()
 
-    def _allocate_block(self, block_id: int):
+    @classmethod
+    def compute_hash(cls, token_ids: list[int], prefix: int = -1):
+        h = xxhash.xxh64()
+        if prefix != -1:
+            h.update(prefix.to_bytes(8, "little"))
+        h.update(np.array(token_ids).tobytes())
+        return h.intdigest()
+
+    def _allocate_block(self, block_id: int) -> Block:
         block = self.blocks[block_id]
         assert block.ref_count == 0
         block.reset()
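
Note on the hunk above: compute_hash becomes a classmethod of BlockManager, and callers chain the previous block's hash in as the prefix, so a block's cache key depends on every token that came before it. The snippet below is an illustrative sketch of that chaining, not code from the commit; chained_block_hashes and block_size are made-up names.

import numpy as np
import xxhash

def chained_block_hashes(token_ids: list[int], block_size: int) -> list[int]:
    # Hash each full block, feeding the previous digest in as the prefix,
    # mirroring how compute_hash(token_ids, h) is called in allocate() below.
    hashes, prefix = [], -1
    for start in range(0, len(token_ids) - len(token_ids) % block_size, block_size):
        h = xxhash.xxh64()
        if prefix != -1:
            h.update(prefix.to_bytes(8, "little"))
        h.update(np.array(token_ids[start:start + block_size]).tobytes())
        prefix = h.intdigest()
        hashes.append(prefix)
    return hashes
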
@@ -50,12 +49,12 @@ class BlockManager:
         self.used_block_ids.add(block_id)
         return self.blocks[block_id]
 
-    def _deallocate_block(self, block_id: int):
+    def _deallocate_block(self, block_id: int) -> Block:
         assert self.blocks[block_id].ref_count == 0
         self.used_block_ids.remove(block_id)
         self.free_block_ids.append(block_id)
 
-    def can_allocate(self, seq: Sequence):
+    def can_allocate(self, seq: Sequence) -> bool:
         return len(self.free_block_ids) >= seq.num_blocks
 
     def allocate(self, seq: Sequence):
@@ -64,7 +63,7 @@ class BlockManager:
         cache_miss = False
         for i in range(seq.num_blocks):
             token_ids = seq.block(i)
-            h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
+            h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
             block_id = self.hash_to_block_id.get(h, -1)
             if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
                 cache_miss = True
@@ -92,7 +91,7 @@ class BlockManager:
         seq.num_cached_tokens = 0
         seq.block_table.clear()
 
-    def can_append(self, seq: Sequence):
+    def can_append(self, seq: Sequence) -> bool:
        return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
 
     def may_append(self, seq: Sequence):
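
Note on can_append above: the comparison uses a bool as an integer, so a free block is only required when len(seq) % block_size == 1, i.e. when the last decode step just spilled into a new block. A tiny sketch with illustrative values, not repo code:

block_size = 4
for seq_len in (4, 5, 6):
    needs_new_block = seq_len % block_size == 1  # True only right after a block boundary is crossed
    print(seq_len, int(needs_new_block))         # 4 -> 0, 5 -> 1, 6 -> 0
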
@@ -107,7 +106,7 @@ class BlockManager:
             assert last_block.hash == -1
             token_ids = seq.block(seq.num_blocks-1)
             prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
-            h = compute_hash(token_ids, prefix)
+            h = self.compute_hash(token_ids, prefix)
             last_block.update(h, token_ids)
             self.hash_to_block_id[h] = last_block.block_id
         else:
@@ -6,10 +6,9 @@ from multiprocessing.shared_memory import SharedMemory
 
 from nanovllm.config import Config
 from nanovllm.engine.sequence import Sequence
-from nanovllm.utils.context import set_context, get_context, reset_context
-from nanovllm.utils.memory import get_gpu_memory
 from nanovllm.models.qwen3 import Qwen3ForCausalLM
 from nanovllm.layers.sampler import Sampler
+from nanovllm.utils.context import set_context, get_context, reset_context
 from nanovllm.utils.loader import load_model
 
 
@@ -93,11 +92,11 @@ class ModelRunner:
     def allocate_kv_cache(self, gpu_memory_utilization):
         config = self.config
         hf_config = config.hf_config
-        total, used, _ = get_gpu_memory()
-        free = total * gpu_memory_utilization - used
+        free, total = torch.cuda.mem_get_info()
+        used = total - free
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
-        config.num_kvcache_blocks = int(free) // block_bytes
+        config.num_kvcache_blocks = int(total * gpu_memory_utilization - used) // block_bytes
         self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
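
Note on allocate_kv_cache above: torch.cuda.mem_get_info() returns (free, total) bytes for the current device, and the block count is whatever fits in total * gpu_memory_utilization minus what is already used. A sketch of the arithmetic with made-up model and GPU numbers, not taken from the commit:

num_layers, block_size, num_kv_heads, head_dim, dtype_bytes = 28, 256, 8, 128, 2
block_bytes = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype_bytes  # K and V planes per block
total, used, util = 24e9, 4e9, 0.9      # pretend a 24 GB card with 4 GB already in use
num_kvcache_blocks = int(total * util - used) // block_bytes
print(block_bytes, num_kvcache_blocks)  # roughly 29 MB per block, roughly 600 blocks
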
@@ -142,7 +141,6 @@ class ModelRunner:
                     end = start + seq.last_block_num_tokens
                 slot_mapping.extend(list(range(start, end)))
         assert len(input_ids) == len(slot_mapping)
-        assert len(input_ids) == cu_seqlens_q[-1]
         if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
             block_tables = self.prepare_block_tables(seqs)
         input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
@@ -53,10 +53,8 @@ class Scheduler:
                 num_seqs += 1
                 self.block_manager.may_append(seq)
                 scheduled_seqs.append(seq)
-        running = deque(scheduled_seqs)
-        running.extend(self.running)
-        self.running = running
+        assert scheduled_seqs
+        self.running.extendleft(reversed(scheduled_seqs))
         return scheduled_seqs, False
 
     def preempt(self, seq: Sequence):
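
Note on the scheduler hunk above: deque.extendleft(reversed(scheduled_seqs)) puts the scheduled sequences back at the front of self.running in their original order, which is what the removed three-line deque rebuild did, just without allocating a new deque. Illustrative check, not repo code:

from collections import deque

running = deque(["c", "d"])          # sequences that were not scheduled this step
scheduled = ["a", "b"]
running.extendleft(reversed(scheduled))
assert list(running) == ["a", "b", "c", "d"]
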
@@ -21,7 +21,6 @@ class VocabParallelEmbedding(nn.Module):
         self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
         self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
         self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
-        self.embedding_dim = embedding_dim
         self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
         self.weight.weight_loader = self.weight_loader
 
@@ -64,12 +64,6 @@ class ColumnParallelLinear(LinearBase):
         super().__init__(input_size, output_size, 0)
         self.input_size_per_partition = input_size
         self.output_size_per_partition = divide(output_size, self.tp_size)
-        self.output_partition_sizes = [self.output_size_per_partition]
-        if hasattr(self, "output_sizes"):
-            self.output_partition_sizes = [
-                divide(output_size, self.tp_size)
-                for output_size in self.output_sizes
-            ]
 
         self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
         self.weight.weight_loader = self.weight_loader
@@ -122,23 +116,14 @@ class QKVParallelLinear(ColumnParallelLinear):
         total_num_kv_heads: int | None = None,
         bias: bool = False,
     ):
-        self.hidden_size = hidden_size
         self.head_size = head_size
         self.total_num_heads = total_num_heads
-        if total_num_kv_heads is None:
-            total_num_kv_heads = total_num_heads
-        self.total_num_kv_heads = total_num_kv_heads
+        self.total_num_kv_heads = total_num_kv_heads or total_num_heads
         tp_size = dist.get_world_size()
         self.num_heads = divide(self.total_num_heads, tp_size)
         self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
-        input_size = self.hidden_size
-        output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
-        self.output_sizes = [
-            self.num_heads * self.head_size * tp_size,  # q_proj
-            self.num_kv_heads * self.head_size * tp_size,  # k_proj
-            self.num_kv_heads * self.head_size * tp_size,  # v_proj
-        ]
-
+        input_size = hidden_size
+        output_size = (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size
         super().__init__(input_size, output_size, bias)
 
     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
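
Note on the QKVParallelLinear hunk above: "total_num_kv_heads or total_num_heads" collapses the old None-check into one line (with the side effect that an explicit 0 would also fall back to the default), and the new output_size formula gives the same number as the old per-rank one whenever the head counts divide evenly across tensor-parallel ranks. A sketch with assumed head counts, not repo code:

total_num_heads, total_num_kv_heads, head_size, tp_size = 32, 8, 128, 4
num_heads = total_num_heads // tp_size
num_kv_heads = total_num_kv_heads // tp_size
old_output_size = (num_heads + 2 * num_kv_heads) * tp_size * head_size
new_output_size = (total_num_heads + 2 * total_num_kv_heads) * head_size
assert old_output_size == new_output_size == 6144
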
@@ -170,7 +155,6 @@ class RowParallelLinear(LinearBase):
         super().__init__(input_size, output_size, 1)
         self.input_size_per_partition = divide(input_size, self.tp_size)
         self.output_size_per_partition = output_size
-        self.output_partition_sizes = [output_size]
 
         self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
         self.weight.weight_loader = self.weight_loader
@@ -1,5 +1,4 @@
 from functools import lru_cache
 
 import torch
 from torch import nn
-
@@ -28,12 +27,9 @@ class RotaryEmbedding(nn.Module):
     ) -> None:
         super().__init__()
         self.head_size = head_size
-        self.rotary_dim = rotary_dim
         assert rotary_dim == head_size
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
-        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
+        inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
+        t = torch.arange(max_position_embeddings, dtype=torch.float)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
         sin = freqs.sin()
@@ -47,8 +43,7 @@ class RotaryEmbedding(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        positions = positions.flatten()
-        num_tokens = positions.shape[0]
+        num_tokens = positions.size(0)
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
         query_shape = query.shape
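
Note on the RotaryEmbedding hunks above: rotary_dim, max_position_embeddings and base are only needed while building the cos/sin table, so they no longer become attributes, and forward() indexes the cached table by position and splits it back into cos and sin. A sketch of that cache layout, assuming (as chunk(2, dim=-1) in forward implies) that cos and sin are concatenated along the last dimension; the numbers are illustrative, not from the commit:

import torch

rotary_dim, max_position, base = 128, 4096, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)    # [max_position, rotary_dim]
cos, sin = cos_sin_cache[torch.tensor([0, 1, 2])].chunk(2, dim=-1)
print(cos_sin_cache.shape, cos.shape, sin.shape)                 # [4096, 128], [3, 64], [3, 64]
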
@@ -26,19 +26,17 @@ class Qwen3Attention(nn.Module):
         rope_scaling: tuple | None = None,
     ) -> None:
         super().__init__()
-        self.hidden_size = hidden_size
         tp_size = dist.get_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
         self.total_num_kv_heads = num_kv_heads
         assert self.total_num_kv_heads % tp_size == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.num_kv_heads = self.total_num_kv_heads // tp_size
         self.head_dim = head_dim or hidden_size // self.total_num_heads
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        self.rope_theta = rope_theta
 
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -57,13 +55,15 @@ class Qwen3Attention(nn.Module):
             self.head_dim,
             rotary_dim=self.head_dim,
             max_position=max_position,
-            base=self.rope_theta,
+            base=rope_theta,
             rope_scaling=rope_scaling,
         )
-        self.attn = Attention(self.num_heads,
+        self.attn = Attention(
+            self.num_heads,
             self.head_dim,
             self.scaling,
-            num_kv_heads=self.num_kv_heads)
+            self.num_kv_heads,
+        )
         self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
         self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
 
@@ -122,9 +122,8 @@ class Qwen3DecoderLayer(nn.Module):
         config: Qwen3Config,
     ) -> None:
         super().__init__()
-        self.hidden_size = config.hidden_size
         self.self_attn = Qwen3Attention(
-            hidden_size=self.hidden_size,
+            hidden_size=config.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             max_position=config.max_position_embeddings,
@@ -139,10 +138,8 @@ class Qwen3DecoderLayer(nn.Module):
             intermediate_size=config.intermediate_size,
             hidden_act=config.hidden_act,
         )
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def forward(
         self,
@@ -155,10 +152,7 @@ class Qwen3DecoderLayer(nn.Module):
             hidden_states = self.input_layernorm(hidden_states)
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-        )
+        hidden_states = self.self_attn(positions, hidden_states)
         hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
         hidden_states = self.mlp(hidden_states)
         return hidden_states, residual
@@ -169,9 +163,8 @@ class Qwen3Model(nn.Module):
     def __init__(
         self,
         config: Qwen3Config,
-    ):
+    ) -> None:
         super().__init__()
-        self.vocab_size = config.vocab_size
         self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
         self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -184,11 +177,7 @@ class Qwen3Model(nn.Module):
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for layer in self.layers:
-            hidden_states, residual = layer(
-                positions,
-                hidden_states,
-                residual,
-            )
+            hidden_states, residual = layer(positions, hidden_states, residual)
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
@@ -205,12 +194,11 @@ class Qwen3ForCausalLM(nn.Module):
     def __init__(
         self,
         config: Qwen3Config
-    ):
+    ) -> None:
         super().__init__()
         self.model = Qwen3Model(config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.tie_word_embeddings = config.tie_word_embeddings
-        if self.tie_word_embeddings:
+        if config.tie_word_embeddings:
             self.lm_head.weight.data = self.model.embed_tokens.weight.data
 
     def forward(
@@ -1,4 +1,3 @@
-from contextlib import contextmanager
 from dataclasses import dataclass
 import torch
 
@@ -1,18 +0,0 @@
-import os
-import torch
-from pynvml import *
-
-
-def get_gpu_memory():
-    torch.cuda.synchronize()
-    nvmlInit()
-    visible_device = list(map(int, os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(',')))
-    cuda_device_idx = torch.cuda.current_device()
-    cuda_device_idx = visible_device[cuda_device_idx]
-    handle = nvmlDeviceGetHandleByIndex(cuda_device_idx)
-    mem_info = nvmlDeviceGetMemoryInfo(handle)
-    total_memory = mem_info.total
-    used_memory = mem_info.used
-    free_memory = mem_info.free
-    nvmlShutdown()
-    return total_memory, used_memory, free_memory
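
Note on the deleted helper above: the pynvml route reports whole-GPU totals via NVML (and needs the CUDA_VISIBLE_DEVICES remapping done by hand), while torch.cuda.mem_get_info() asks the CUDA driver for free/total bytes on the current device directly, which is why the file can go away. Minimal side-by-side sketch, not repo code; device index 0 is assumed:

import torch
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlShutdown

free, total = torch.cuda.mem_get_info()                        # bytes, current CUDA device
nvmlInit()
info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(0))  # bytes, physical GPU 0
nvmlShutdown()
print(free, total, info.free, info.used, info.total)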