support tensor parallel
@@ -1,5 +1,6 @@
 import torch
 from torch import nn
+import torch.distributed as dist
 from transformers import Qwen3Config
 
 from nanovllm.layers.activation import SiluAndMul
@@ -26,7 +27,7 @@ class Qwen3Attention(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = 1 # get_tensor_model_parallel_world_size()
+        tp_size = dist.get_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
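
The effect of this change: the tensor-parallel degree is now taken from the size of the torch.distributed process group, and each rank owns total_num_heads // tp_size attention heads, with the assert guaranteeing an even split. Below is a minimal sketch of that arithmetic; the single-process gloo group and the split_heads helper are illustrative only and not part of the commit.

```python
import os

import torch.distributed as dist


def split_heads(total_num_heads: int) -> int:
    """Return the number of attention heads owned by this rank."""
    tp_size = dist.get_world_size()  # tensor-parallel degree = world size
    assert total_num_heads % tp_size == 0, (
        f"{total_num_heads} heads do not divide evenly across {tp_size} ranks"
    )
    return total_num_heads // tp_size


if __name__ == "__main__":
    # Single-process group so the sketch runs without a multi-GPU launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    print(split_heads(16))  # 16 heads on 1 rank -> 16 heads per rank
    dist.destroy_process_group()
```

Launched with torchrun --nproc_per_node=2, the same split would give each rank 8 of the 16 heads, which is why the hardcoded tp_size = 1 placeholder had to go.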