simplify
This commit is contained in:
@@ -5,14 +5,6 @@ import numpy as np
|
|||||||
from nanovllm.engine.sequence import Sequence
|
from nanovllm.engine.sequence import Sequence
|
||||||
|
|
||||||
|
|
||||||
def compute_hash(token_ids: list[int], prefix: int = -1):
|
|
||||||
h = xxhash.xxh64()
|
|
||||||
if prefix != -1:
|
|
||||||
h.update(prefix.to_bytes(8, "little"))
|
|
||||||
h.update(np.array(token_ids).tobytes())
|
|
||||||
return h.intdigest()
|
|
||||||
|
|
||||||
|
|
||||||
class Block:
|
class Block:
|
||||||
|
|
||||||
def __init__(self, block_id):
|
def __init__(self, block_id):
|
||||||
@@ -22,7 +14,6 @@ class Block:
|
|||||||
self.token_ids = []
|
self.token_ids = []
|
||||||
|
|
||||||
def update(self, hash: int, token_ids: list[int]):
|
def update(self, hash: int, token_ids: list[int]):
|
||||||
assert hash != -1
|
|
||||||
self.hash = hash
|
self.hash = hash
|
||||||
self.token_ids = token_ids
|
self.token_ids = token_ids
|
||||||
|
|
||||||
@@ -42,7 +33,15 @@ class BlockManager:
|
|||||||
self.free_block_ids: deque[int] = deque(range(num_blocks))
|
self.free_block_ids: deque[int] = deque(range(num_blocks))
|
||||||
self.used_block_ids: set[int] = set()
|
self.used_block_ids: set[int] = set()
|
||||||
|
|
||||||
def _allocate_block(self, block_id: int):
|
@classmethod
|
||||||
|
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
|
||||||
|
h = xxhash.xxh64()
|
||||||
|
if prefix != -1:
|
||||||
|
h.update(prefix.to_bytes(8, "little"))
|
||||||
|
h.update(np.array(token_ids).tobytes())
|
||||||
|
return h.intdigest()
|
||||||
|
|
||||||
|
def _allocate_block(self, block_id: int) -> Block:
|
||||||
block = self.blocks[block_id]
|
block = self.blocks[block_id]
|
||||||
assert block.ref_count == 0
|
assert block.ref_count == 0
|
||||||
block.reset()
|
block.reset()
|
||||||
@@ -50,12 +49,12 @@ class BlockManager:
|
|||||||
self.used_block_ids.add(block_id)
|
self.used_block_ids.add(block_id)
|
||||||
return self.blocks[block_id]
|
return self.blocks[block_id]
|
||||||
|
|
||||||
def _deallocate_block(self, block_id: int):
|
def _deallocate_block(self, block_id: int) -> Block:
|
||||||
assert self.blocks[block_id].ref_count == 0
|
assert self.blocks[block_id].ref_count == 0
|
||||||
self.used_block_ids.remove(block_id)
|
self.used_block_ids.remove(block_id)
|
||||||
self.free_block_ids.append(block_id)
|
self.free_block_ids.append(block_id)
|
||||||
|
|
||||||
def can_allocate(self, seq: Sequence):
|
def can_allocate(self, seq: Sequence) -> bool:
|
||||||
return len(self.free_block_ids) >= seq.num_blocks
|
return len(self.free_block_ids) >= seq.num_blocks
|
||||||
|
|
||||||
def allocate(self, seq: Sequence):
|
def allocate(self, seq: Sequence):
|
||||||
@@ -64,7 +63,7 @@ class BlockManager:
|
|||||||
cache_miss = False
|
cache_miss = False
|
||||||
for i in range(seq.num_blocks):
|
for i in range(seq.num_blocks):
|
||||||
token_ids = seq.block(i)
|
token_ids = seq.block(i)
|
||||||
h = compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
|
h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1
|
||||||
block_id = self.hash_to_block_id.get(h, -1)
|
block_id = self.hash_to_block_id.get(h, -1)
|
||||||
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
|
if block_id == -1 or self.blocks[block_id].token_ids != token_ids:
|
||||||
cache_miss = True
|
cache_miss = True
|
||||||
@@ -92,7 +91,7 @@ class BlockManager:
|
|||||||
seq.num_cached_tokens = 0
|
seq.num_cached_tokens = 0
|
||||||
seq.block_table.clear()
|
seq.block_table.clear()
|
||||||
|
|
||||||
def can_append(self, seq: Sequence):
|
def can_append(self, seq: Sequence) -> bool:
|
||||||
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
|
return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
|
||||||
|
|
||||||
def may_append(self, seq: Sequence):
|
def may_append(self, seq: Sequence):
|
||||||
@@ -107,7 +106,7 @@ class BlockManager:
|
|||||||
assert last_block.hash == -1
|
assert last_block.hash == -1
|
||||||
token_ids = seq.block(seq.num_blocks-1)
|
token_ids = seq.block(seq.num_blocks-1)
|
||||||
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
|
prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1
|
||||||
h = compute_hash(token_ids, prefix)
|
h = self.compute_hash(token_ids, prefix)
|
||||||
last_block.update(h, token_ids)
|
last_block.update(h, token_ids)
|
||||||
self.hash_to_block_id[h] = last_block.block_id
|
self.hash_to_block_id[h] = last_block.block_id
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -6,10 +6,9 @@ from multiprocessing.shared_memory import SharedMemory
|
|||||||
|
|
||||||
from nanovllm.config import Config
|
from nanovllm.config import Config
|
||||||
from nanovllm.engine.sequence import Sequence
|
from nanovllm.engine.sequence import Sequence
|
||||||
from nanovllm.utils.context import set_context, get_context, reset_context
|
|
||||||
from nanovllm.utils.memory import get_gpu_memory
|
|
||||||
from nanovllm.models.qwen3 import Qwen3ForCausalLM
|
from nanovllm.models.qwen3 import Qwen3ForCausalLM
|
||||||
from nanovllm.layers.sampler import Sampler
|
from nanovllm.layers.sampler import Sampler
|
||||||
|
from nanovllm.utils.context import set_context, get_context, reset_context
|
||||||
from nanovllm.utils.loader import load_model
|
from nanovllm.utils.loader import load_model
|
||||||
|
|
||||||
|
|
||||||
@@ -93,11 +92,11 @@ class ModelRunner:
|
|||||||
def allocate_kv_cache(self, gpu_memory_utilization):
|
def allocate_kv_cache(self, gpu_memory_utilization):
|
||||||
config = self.config
|
config = self.config
|
||||||
hf_config = config.hf_config
|
hf_config = config.hf_config
|
||||||
total, used, _ = get_gpu_memory()
|
free, total = torch.cuda.mem_get_info()
|
||||||
free = total * gpu_memory_utilization - used
|
used = total - free
|
||||||
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
num_kv_heads = hf_config.num_key_value_heads // self.world_size
|
||||||
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
|
block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
|
||||||
config.num_kvcache_blocks = int(free) // block_bytes
|
config.num_kvcache_blocks = int(total * gpu_memory_utilization - used) // block_bytes
|
||||||
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
|
self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
|
||||||
layer_id = 0
|
layer_id = 0
|
||||||
for module in self.model.modules():
|
for module in self.model.modules():
|
||||||
@@ -142,7 +141,6 @@ class ModelRunner:
|
|||||||
end = start + seq.last_block_num_tokens
|
end = start + seq.last_block_num_tokens
|
||||||
slot_mapping.extend(list(range(start, end)))
|
slot_mapping.extend(list(range(start, end)))
|
||||||
assert len(input_ids) == len(slot_mapping)
|
assert len(input_ids) == len(slot_mapping)
|
||||||
assert len(input_ids) == cu_seqlens_q[-1]
|
|
||||||
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
|
if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache
|
||||||
block_tables = self.prepare_block_tables(seqs)
|
block_tables = self.prepare_block_tables(seqs)
|
||||||
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True)
|
||||||
|
|||||||
@@ -53,10 +53,8 @@ class Scheduler:
|
|||||||
num_seqs += 1
|
num_seqs += 1
|
||||||
self.block_manager.may_append(seq)
|
self.block_manager.may_append(seq)
|
||||||
scheduled_seqs.append(seq)
|
scheduled_seqs.append(seq)
|
||||||
running = deque(scheduled_seqs)
|
|
||||||
running.extend(self.running)
|
|
||||||
self.running = running
|
|
||||||
assert scheduled_seqs
|
assert scheduled_seqs
|
||||||
|
self.running.extendleft(reversed(scheduled_seqs))
|
||||||
return scheduled_seqs, False
|
return scheduled_seqs, False
|
||||||
|
|
||||||
def preempt(self, seq: Sequence):
|
def preempt(self, seq: Sequence):
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ class VocabParallelEmbedding(nn.Module):
|
|||||||
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
|
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
|
||||||
self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
|
self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
|
||||||
self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
|
self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
|
||||||
self.embedding_dim = embedding_dim
|
|
||||||
self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
|
self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
|
||||||
self.weight.weight_loader = self.weight_loader
|
self.weight.weight_loader = self.weight_loader
|
||||||
|
|
||||||
|
|||||||
@@ -64,12 +64,6 @@ class ColumnParallelLinear(LinearBase):
|
|||||||
super().__init__(input_size, output_size, 0)
|
super().__init__(input_size, output_size, 0)
|
||||||
self.input_size_per_partition = input_size
|
self.input_size_per_partition = input_size
|
||||||
self.output_size_per_partition = divide(output_size, self.tp_size)
|
self.output_size_per_partition = divide(output_size, self.tp_size)
|
||||||
self.output_partition_sizes = [self.output_size_per_partition]
|
|
||||||
if hasattr(self, "output_sizes"):
|
|
||||||
self.output_partition_sizes = [
|
|
||||||
divide(output_size, self.tp_size)
|
|
||||||
for output_size in self.output_sizes
|
|
||||||
]
|
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
|
self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
|
||||||
self.weight.weight_loader = self.weight_loader
|
self.weight.weight_loader = self.weight_loader
|
||||||
@@ -122,23 +116,14 @@ class QKVParallelLinear(ColumnParallelLinear):
|
|||||||
total_num_kv_heads: int | None = None,
|
total_num_kv_heads: int | None = None,
|
||||||
bias: bool = False,
|
bias: bool = False,
|
||||||
):
|
):
|
||||||
self.hidden_size = hidden_size
|
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.total_num_heads = total_num_heads
|
self.total_num_heads = total_num_heads
|
||||||
if total_num_kv_heads is None:
|
self.total_num_kv_heads = total_num_kv_heads or total_num_heads
|
||||||
total_num_kv_heads = total_num_heads
|
|
||||||
self.total_num_kv_heads = total_num_kv_heads
|
|
||||||
tp_size = dist.get_world_size()
|
tp_size = dist.get_world_size()
|
||||||
self.num_heads = divide(self.total_num_heads, tp_size)
|
self.num_heads = divide(self.total_num_heads, tp_size)
|
||||||
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
|
||||||
input_size = self.hidden_size
|
input_size = hidden_size
|
||||||
output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
|
output_size = (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size
|
||||||
self.output_sizes = [
|
|
||||||
self.num_heads * self.head_size * tp_size, # q_proj
|
|
||||||
self.num_kv_heads * self.head_size * tp_size, # k_proj
|
|
||||||
self.num_kv_heads * self.head_size * tp_size, # v_proj
|
|
||||||
]
|
|
||||||
|
|
||||||
super().__init__(input_size, output_size, bias)
|
super().__init__(input_size, output_size, bias)
|
||||||
|
|
||||||
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
|
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
|
||||||
@@ -170,7 +155,6 @@ class RowParallelLinear(LinearBase):
|
|||||||
super().__init__(input_size, output_size, 1)
|
super().__init__(input_size, output_size, 1)
|
||||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||||
self.output_size_per_partition = output_size
|
self.output_size_per_partition = output_size
|
||||||
self.output_partition_sizes = [output_size]
|
|
||||||
|
|
||||||
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
|
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
|
||||||
self.weight.weight_loader = self.weight_loader
|
self.weight.weight_loader = self.weight_loader
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
@@ -28,12 +27,9 @@ class RotaryEmbedding(nn.Module):
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head_size = head_size
|
self.head_size = head_size
|
||||||
self.rotary_dim = rotary_dim
|
|
||||||
assert rotary_dim == head_size
|
assert rotary_dim == head_size
|
||||||
self.max_position_embeddings = max_position_embeddings
|
inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
|
||||||
self.base = base
|
t = torch.arange(max_position_embeddings, dtype=torch.float)
|
||||||
inv_freq = 1.0 / (base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
|
|
||||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
|
||||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||||
cos = freqs.cos()
|
cos = freqs.cos()
|
||||||
sin = freqs.sin()
|
sin = freqs.sin()
|
||||||
@@ -47,8 +43,7 @@ class RotaryEmbedding(nn.Module):
|
|||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
key: torch.Tensor,
|
key: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
positions = positions.flatten()
|
num_tokens = positions.size(0)
|
||||||
num_tokens = positions.shape[0]
|
|
||||||
cos_sin = self.cos_sin_cache[positions]
|
cos_sin = self.cos_sin_cache[positions]
|
||||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||||
query_shape = query.shape
|
query_shape = query.shape
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from nanovllm.layers.attention import Attention
|
|||||||
from nanovllm.layers.layernorm import RMSNorm
|
from nanovllm.layers.layernorm import RMSNorm
|
||||||
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
|
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
|
||||||
from nanovllm.layers.rotary_embedding import get_rope
|
from nanovllm.layers.rotary_embedding import get_rope
|
||||||
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
|
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
|
||||||
|
|
||||||
|
|
||||||
class Qwen3Attention(nn.Module):
|
class Qwen3Attention(nn.Module):
|
||||||
@@ -26,19 +26,17 @@ class Qwen3Attention(nn.Module):
|
|||||||
rope_scaling: tuple | None = None,
|
rope_scaling: tuple | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = hidden_size
|
|
||||||
tp_size = dist.get_world_size()
|
tp_size = dist.get_world_size()
|
||||||
self.total_num_heads = num_heads
|
self.total_num_heads = num_heads
|
||||||
assert self.total_num_heads % tp_size == 0
|
assert self.total_num_heads % tp_size == 0
|
||||||
self.num_heads = self.total_num_heads // tp_size
|
self.num_heads = self.total_num_heads // tp_size
|
||||||
self.total_num_kv_heads = num_kv_heads
|
self.total_num_kv_heads = num_kv_heads
|
||||||
assert self.total_num_kv_heads % tp_size == 0
|
assert self.total_num_kv_heads % tp_size == 0
|
||||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
self.num_kv_heads = self.total_num_kv_heads // tp_size
|
||||||
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
||||||
self.q_size = self.num_heads * self.head_dim
|
self.q_size = self.num_heads * self.head_dim
|
||||||
self.kv_size = self.num_kv_heads * self.head_dim
|
self.kv_size = self.num_kv_heads * self.head_dim
|
||||||
self.scaling = self.head_dim**-0.5
|
self.scaling = self.head_dim**-0.5
|
||||||
self.rope_theta = rope_theta
|
|
||||||
|
|
||||||
self.qkv_proj = QKVParallelLinear(
|
self.qkv_proj = QKVParallelLinear(
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@@ -57,13 +55,15 @@ class Qwen3Attention(nn.Module):
|
|||||||
self.head_dim,
|
self.head_dim,
|
||||||
rotary_dim=self.head_dim,
|
rotary_dim=self.head_dim,
|
||||||
max_position=max_position,
|
max_position=max_position,
|
||||||
base=self.rope_theta,
|
base=rope_theta,
|
||||||
rope_scaling=rope_scaling,
|
rope_scaling=rope_scaling,
|
||||||
)
|
)
|
||||||
self.attn = Attention(self.num_heads,
|
self.attn = Attention(
|
||||||
self.head_dim,
|
self.num_heads,
|
||||||
self.scaling,
|
self.head_dim,
|
||||||
num_kv_heads=self.num_kv_heads)
|
self.scaling,
|
||||||
|
self.num_kv_heads,
|
||||||
|
)
|
||||||
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||||
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
|
||||||
|
|
||||||
@@ -122,9 +122,8 @@ class Qwen3DecoderLayer(nn.Module):
|
|||||||
config: Qwen3Config,
|
config: Qwen3Config,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
|
||||||
self.self_attn = Qwen3Attention(
|
self.self_attn = Qwen3Attention(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=config.hidden_size,
|
||||||
num_heads=config.num_attention_heads,
|
num_heads=config.num_attention_heads,
|
||||||
num_kv_heads=config.num_key_value_heads,
|
num_kv_heads=config.num_key_value_heads,
|
||||||
max_position=config.max_position_embeddings,
|
max_position=config.max_position_embeddings,
|
||||||
@@ -139,10 +138,8 @@ class Qwen3DecoderLayer(nn.Module):
|
|||||||
intermediate_size=config.intermediate_size,
|
intermediate_size=config.intermediate_size,
|
||||||
hidden_act=config.hidden_act,
|
hidden_act=config.hidden_act,
|
||||||
)
|
)
|
||||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
eps=config.rms_norm_eps)
|
self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
|
||||||
eps=config.rms_norm_eps)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -155,10 +152,7 @@ class Qwen3DecoderLayer(nn.Module):
|
|||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
else:
|
else:
|
||||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||||
hidden_states = self.self_attn(
|
hidden_states = self.self_attn(positions, hidden_states)
|
||||||
positions=positions,
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
)
|
|
||||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||||
hidden_states = self.mlp(hidden_states)
|
hidden_states = self.mlp(hidden_states)
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
@@ -169,9 +163,8 @@ class Qwen3Model(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Qwen3Config,
|
config: Qwen3Config,
|
||||||
):
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vocab_size = config.vocab_size
|
|
||||||
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
|
self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
|
||||||
self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
self.layers = nn.ModuleList([Qwen3DecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
@@ -184,11 +177,7 @@ class Qwen3Model(nn.Module):
|
|||||||
hidden_states = self.embed_tokens(input_ids)
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
residual = None
|
residual = None
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||||
positions,
|
|
||||||
hidden_states,
|
|
||||||
residual,
|
|
||||||
)
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -205,12 +194,11 @@ class Qwen3ForCausalLM(nn.Module):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Qwen3Config
|
config: Qwen3Config
|
||||||
):
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = Qwen3Model(config)
|
self.model = Qwen3Model(config)
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||||
self.tie_word_embeddings = config.tie_word_embeddings
|
if config.tie_word_embeddings:
|
||||||
if self.tie_word_embeddings:
|
|
||||||
self.lm_head.weight.data = self.model.embed_tokens.weight.data
|
self.lm_head.weight.data = self.model.embed_tokens.weight.data
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from contextlib import contextmanager
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
import os
|
|
||||||
import torch
|
|
||||||
from pynvml import *
|
|
||||||
|
|
||||||
|
|
||||||
def get_gpu_memory():
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
nvmlInit()
|
|
||||||
visible_device = list(map(int, os.getenv("CUDA_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7").split(',')))
|
|
||||||
cuda_device_idx = torch.cuda.current_device()
|
|
||||||
cuda_device_idx = visible_device[cuda_device_idx]
|
|
||||||
handle = nvmlDeviceGetHandleByIndex(cuda_device_idx)
|
|
||||||
mem_info = nvmlDeviceGetMemoryInfo(handle)
|
|
||||||
total_memory = mem_info.total
|
|
||||||
used_memory = mem_info.used
|
|
||||||
free_memory = mem_info.free
|
|
||||||
nvmlShutdown()
|
|
||||||
return total_memory, used_memory, free_memory
|
|
||||||
Reference in New Issue
Block a user