remove assert shape
@@ -31,7 +31,7 @@ class RMSNorm(nn.Module):
         self,
         x: torch.Tensor,
         residual: torch.Tensor,
-    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         orig_dtype = x.dtype
         x = x.to(torch.float32).add_(residual.to(torch.float32))
         residual = x.to(orig_dtype)
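For context, a minimal sketch of the method this hunk touches. Only the signature and the first three body lines appear in the diff; the method name, the eps/weight attributes, and the normalization tail are assumptions filled in from standard RMSNorm implementations. Since this variant always receives a residual and always returns a pair, the narrowed annotation is the accurate one:

    import torch
    from torch import nn

    class RMSNorm(nn.Module):
        def __init__(self, hidden_size: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(hidden_size))

        # Hypothetical name; the diff shows only the signature onward.
        def add_rms_forward(
            self,
            x: torch.Tensor,
            residual: torch.Tensor,
        ) -> tuple[torch.Tensor, torch.Tensor]:
            orig_dtype = x.dtype
            # Fuse the residual add in fp32 for numerical stability, then
            # hand the updated residual back in the original dtype.
            x = x.to(torch.float32).add_(residual.to(torch.float32))
            residual = x.to(orig_dtype)
            var = x.pow(2).mean(dim=-1, keepdim=True)
            x = x * torch.rsqrt(var + self.eps)
            return (x * self.weight.float()).to(orig_dtype), residual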
@@ -46,7 +46,6 @@ class ReplicatedLinear(LinearBase):
         self.register_parameter("bias", None)

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
-        assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
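One behavioral note on dropping the shape assert here: Tensor.copy_ still raises when the loaded weight cannot broadcast to the parameter, but it silently accepts a weight that merely broadcasts. A quick illustration:

    import torch

    param = torch.empty(4, 8)
    param.copy_(torch.randn(4, 8))    # exact match: fine
    param.copy_(torch.randn(1, 8))    # broadcastable: row silently repeated
    # param.copy_(torch.randn(4, 7))  # RuntimeError: not broadcastable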
@@ -78,7 +77,6 @@ class ColumnParallelLinear(LinearBase):
         shard_size = param_data.size(self.tp_dim)
         start_idx = self.tp_rank * shard_size
         loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
-        assert param_data.size() == loaded_weight.size()
         param_data.copy_(loaded_weight)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
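The sharding arithmetic above is the standard tensor-parallel pattern: each rank copies the contiguous slice of the full checkpoint weight corresponding to its rank. A self-contained sketch with made-up sizes:

    import torch

    full_weight = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)
    tp_size, tp_dim = 4, 0                   # 4 ranks, shard along dim 0
    shard_size = full_weight.size(tp_dim) // tp_size

    for tp_rank in range(tp_size):
        start_idx = tp_rank * shard_size
        shard = full_weight.narrow(tp_dim, start_idx, shard_size)
        print(tp_rank, shard.shape)          # rank r gets rows [2r, 2r + 2)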
@@ -102,7 +100,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
         param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
         loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
-        assert param_data.size() == loaded_weight.size()
         param_data.copy_(loaded_weight)

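Note that chunk(tp_size, dim)[tp_rank] selects the same slice as the narrow pattern used in ColumnParallelLinear, provided the dimension divides evenly across ranks:

    import torch

    w = torch.randn(12, 4)
    tp_size, tp_dim, tp_rank = 3, 0, 1
    shard_size = w.size(tp_dim) // tp_size
    assert torch.equal(
        w.chunk(tp_size, tp_dim)[tp_rank],
        w.narrow(tp_dim, tp_rank * shard_size, shard_size),
    )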
@@ -140,7 +137,6 @@ class QKVParallelLinear(ColumnParallelLinear):
         shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
         param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
         loaded_weight = loaded_weight.chunk(self.tp_size, self.tp_dim)[self.tp_rank]
-        assert param_data.size() == loaded_weight.size()
         param_data.copy_(loaded_weight)

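The shard_offset line reflects the packed layout of the fused QKV parameter. Assuming the usual q/k/v ordering (the hunk shows only the v offset), the per-section offsets along the output dim would be:

    num_heads, num_kv_heads, head_size = 8, 2, 64   # made-up per-rank sizes

    # Hypothetical packed layout along the output dim, consistent with the
    # offset computed in the hunk above: [ q | k | v ]
    q_offset = 0
    k_offset = num_heads * head_size
    v_offset = num_heads * head_size + num_kv_heads * head_size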
@@ -169,7 +165,6 @@ class RowParallelLinear(LinearBase):
         shard_size = param_data.size(self.tp_dim)
         start_idx = self.tp_rank * shard_size
         loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
-        assert param_data.size() == loaded_weight.size()
         param_data.copy_(loaded_weight)

     def forward(self, x: torch.Tensor) -> torch.Tensor:

@@ -50,7 +50,6 @@ class Qwen3Attention(nn.Module):
             hidden_size,
             bias=False,
         )
-
         self.rotary_emb = get_rope(
             self.head_dim,
             rotary_dim=self.head_dim,
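For completeness, the rotary module built here is typically applied to the query/key tensors later in forward; the (positions, q, k) -> (q, k) calling convention below is an assumption from common implementations, not shown in this diff:

    # Hypothetical downstream usage inside Qwen3Attention.forward:
    q, k = self.rotary_emb(positions, q, k)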