73 lines
2.6 KiB
Python
73 lines
2.6 KiB
Python
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
import torch.distributed as dist
|
|
|
|
from nanovllm.utils.context import get_context
|
|
|
|
|
|
class VocabParallelEmbedding(nn.Module):
|
|
|
|
def __init__(
|
|
self,
|
|
num_embeddings: int,
|
|
embedding_dim: int,
|
|
):
|
|
super().__init__()
|
|
self.tp_rank = dist.get_rank()
|
|
self.tp_size = dist.get_world_size()
|
|
assert num_embeddings % self.tp_size == 0
|
|
self.num_embeddings = num_embeddings
|
|
self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
|
|
self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
|
|
self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
|
|
self.embedding_dim = embedding_dim
|
|
self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
|
|
self.weight.weight_loader = self.weight_loader
|
|
|
|
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
|
|
param_data = param.data
|
|
shard_size = param_data.size(0)
|
|
start_idx = self.tp_rank * shard_size
|
|
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
|
|
assert param_data.size() == loaded_weight.size()
|
|
param_data.copy_(loaded_weight)
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
if self.tp_size > 1:
|
|
mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
|
|
x = mask * (x - self.vocab_start_idx)
|
|
y = F.embedding(x, self.weight)
|
|
if self.tp_size > 1:
|
|
y = mask.unsqueeze(1) * y
|
|
dist.all_reduce(y)
|
|
return y
|
|
|
|
|
|
class ParallelLMHead(VocabParallelEmbedding):
|
|
|
|
def __init__(
|
|
self,
|
|
num_embeddings: int,
|
|
embedding_dim: int,
|
|
bias: bool = False,
|
|
):
|
|
super().__init__(num_embeddings, embedding_dim)
|
|
if bias:
|
|
self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
|
|
self.bias.weight_loader = self.weight_loader
|
|
else:
|
|
self.register_parameter("bias", None)
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
context = get_context()
|
|
if context.is_prefill:
|
|
last_indices = context.cu_seqlens_q[1:] - 1
|
|
x = x[last_indices].contiguous()
|
|
logits = F.linear(x, self.weight, self.bias)
|
|
if self.tp_size > 1:
|
|
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
|
|
dist.gather(logits, all_logits, 0)
|
|
logits = torch.cat(all_logits, -1) if self.tp_rank == 0 else None
|
|
return logits
|