Support tensor parallelism

This commit is contained in:
cheunglei
2025-06-15 01:31:24 +08:00
parent b6136383c9
commit 53b3ef2e32
9 changed files with 102 additions and 31 deletions

View File

@@ -1,5 +1,6 @@
import torch
from torch import nn
import torch.distributed as dist
from transformers import Qwen3Config
from nanovllm.layers.activation import SiluAndMul
@@ -26,7 +27,7 @@ class Qwen3Attention(nn.Module):
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = 1 # get_tensor_model_parallel_world_size()
tp_size = dist.get_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size