remove assert shape

This commit is contained in:
GeeeekExplorer
2025-06-27 23:00:30 +08:00
parent 2de882a395
commit 38baf0bbe4
3 changed files with 1 additions and 7 deletions

View File

@@ -46,7 +46,6 @@ class ReplicatedLinear(LinearBase):
self.register_parameter("bias", None)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -78,7 +77,6 @@ class ColumnParallelLinear(LinearBase):
shard_size = param_data.size(self.tp_dim)
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
assert param_data.size() == loaded_weight.size()
param_data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -102,7 +100,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
assert param_data.size() == loaded_weight.size()
param_data.copy_(loaded_weight)
@@ -140,7 +137,6 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
assert param_data.size() == loaded_weight.size()
param_data.copy_(loaded_weight)
@@ -169,7 +165,6 @@ class RowParallelLinear(LinearBase):
shard_size = param_data.size(self.tp_dim)
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
assert param_data.size() == loaded_weight.size()
param_data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor: