simplify
This commit is contained in:
@@ -29,7 +29,6 @@ class VocabParallelEmbedding(nn.Module):
|
||||
shard_size = param_data.size(0)
|
||||
start_idx = self.tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
|
||||
assert param_data.size() == loaded_weight.size()
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
@@ -51,19 +50,15 @@ class ParallelLMHead(VocabParallelEmbedding):
|
||||
embedding_dim: int,
|
||||
bias: bool = False,
|
||||
):
|
||||
assert not bias
|
||||
super().__init__(num_embeddings, embedding_dim)
|
||||
if bias:
|
||||
self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
|
||||
self.bias.weight_loader = self.weight_loader
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
context = get_context()
|
||||
if context.is_prefill:
|
||||
last_indices = context.cu_seqlens_q[1:] - 1
|
||||
x = x[last_indices].contiguous()
|
||||
logits = F.linear(x, self.weight, self.bias)
|
||||
logits = F.linear(x, self.weight)
|
||||
if self.tp_size > 1:
|
||||
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
|
||||
dist.gather(logits, all_logits, 0)
|
||||
|
||||
Reference in New Issue
Block a user