GeeeekExplorer
2025-08-31 19:44:57 +08:00
parent 6a6d217de7
commit df99418f7d
11 changed files with 47 additions and 96 deletions

View File

@@ -19,11 +19,12 @@ def store_kvcache_kernel(
D: tl.constexpr,
):
idx = tl.program_id(0)
slot = tl.load(slot_mapping_ptr + idx)
if slot == -1: return
key_offsets = idx * key_stride + tl.arange(0, D)
value_offsets = idx * value_stride + tl.arange(0, D)
key = tl.load(key_ptr + key_offsets)
value = tl.load(value_ptr + value_offsets)
slot = tl.load(slot_mapping_ptr + idx)
cache_offsets = slot * D + tl.arange(0, D)
tl.store(k_cache_ptr + cache_offsets, key)
tl.store(v_cache_ptr + cache_offsets, value)
@@ -56,10 +57,6 @@ class Attention(nn.Module):
self.k_cache = self.v_cache = torch.tensor([])
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
o: torch.Tensor
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_kv_heads, self.head_dim)
v = v.view(-1, self.num_kv_heads, self.head_dim)
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache
if k_cache.numel() and v_cache.numel():
@@ -75,5 +72,4 @@ class Attention(nn.Module):
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
o = o.view(-1, self.num_heads * self.head_dim)
return o
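For reference, a minimal pure-PyTorch sketch of what the store_kvcache kernel above computes: one program per token loads its cache slot from slot_mapping, skips padded tokens (slot == -1), and copies the token's flattened key/value vector of length D into that slot of the caches. The tensor names and shapes below are illustrative assumptions, not code from the repository.

import torch

def store_kvcache_reference(key, value, k_cache, v_cache, slot_mapping):
    # key/value: [N, num_kv_heads, head_dim]; caches viewable as [num_slots, D]
    # with D = num_kv_heads * head_dim. Mirrors the kernel, one token at a time.
    N = key.size(0)
    D = key[0].numel()
    k_flat, v_flat = k_cache.view(-1, D), v_cache.view(-1, D)
    for i in range(N):
        slot = int(slot_mapping[i])
        if slot == -1:  # padding slot: nothing to store, like the kernel's early return
            continue
        k_flat[slot] = key[i].reshape(D)
        v_flat[slot] = value[i].reshape(D)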

View File

@@ -29,7 +29,6 @@ class VocabParallelEmbedding(nn.Module):
shard_size = param_data.size(0)
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
assert param_data.size() == loaded_weight.size()
param_data.copy_(loaded_weight)
def forward(self, x: torch.Tensor):
@@ -51,19 +50,15 @@ class ParallelLMHead(VocabParallelEmbedding):
embedding_dim: int,
bias: bool = False,
):
assert not bias
super().__init__(num_embeddings, embedding_dim)
if bias:
self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor):
context = get_context()
if context.is_prefill:
last_indices = context.cu_seqlens_q[1:] - 1
x = x[last_indices].contiguous()
logits = F.linear(x, self.weight, self.bias)
logits = F.linear(x, self.weight)
if self.tp_size > 1:
all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
dist.gather(logits, all_logits, 0)
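A note on the prefill path kept in this hunk: only the final position of each packed sequence is fed to the LM head, and those positions are recovered from the cumulative sequence lengths. A tiny self-contained sketch with made-up shapes:

import torch
import torch.nn.functional as F

hidden = torch.randn(7, 16)             # 7 packed tokens, hidden size 16 (toy numbers)
cu_seqlens_q = torch.tensor([0, 3, 7])  # two sequences, of lengths 3 and 4
last_indices = cu_seqlens_q[1:] - 1     # tensor([2, 6]): last token of each sequence
x = hidden[last_indices].contiguous()   # [2, 16]: one row per sequence
weight = torch.randn(32, 16)            # stand-in LM head shard (vocab-per-rank 32)
logits = F.linear(x, weight)            # [2, 32]; gathered to rank 0 when tp_size > 1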

View File

@@ -10,7 +10,6 @@ class RMSNorm(nn.Module):
eps: float = 1e-6,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
@@ -20,7 +19,7 @@ class RMSNorm(nn.Module):
x: torch.Tensor,
) -> torch.Tensor:
orig_dtype = x.dtype
x = x.to(torch.float32)
x = x.float()
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
x = x.to(orig_dtype).mul_(self.weight)
@@ -33,7 +32,7 @@ class RMSNorm(nn.Module):
residual: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
orig_dtype = x.dtype
x = x.to(torch.float32).add_(residual.to(torch.float32))
x = x.float().add_(residual.float())
residual = x.to(orig_dtype)
var = x.pow(2).mean(dim=-1, keepdim=True)
x.mul_(torch.rsqrt(var + self.eps))
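A minimal eager-mode reference for the RMSNorm math in these hunks (cast to float32, optional fused residual add, normalize by the root mean square, scale by the learned weight); the function name and signature are illustrative:

import torch

def rms_norm_reference(x, weight, residual=None, eps=1e-6):
    orig_dtype = x.dtype
    x = x.float()
    if residual is not None:
        x = x + residual.float()
        residual = x.to(orig_dtype)   # returned residual is the post-add activation
    var = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(var + eps)
    out = x.to(orig_dtype) * weight
    return out if residual is None else (out, residual)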

View File

@@ -15,14 +15,20 @@ class LinearBase(nn.Module):
self,
input_size: int,
output_size: int,
bias: bool = False,
tp_dim: int | None = None,
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.tp_dim = tp_dim
self.tp_rank = dist.get_rank()
self.tp_size = dist.get_world_size()
self.weight = nn.Parameter(torch.empty(output_size, input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
def forward(self, x: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
@@ -36,14 +42,7 @@ class ReplicatedLinear(LinearBase):
output_size: int,
bias: bool = False,
):
super().__init__(input_size, output_size)
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(self.output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
super().__init__(input_size, output_size, bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param.data.copy_(loaded_weight)
@@ -60,17 +59,8 @@ class ColumnParallelLinear(LinearBase):
output_size: int,
bias: bool = False,
):
super().__init__(input_size, output_size, 0)
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
tp_size = dist.get_world_size()
super().__init__(input_size, divide(output_size, tp_size), bias, 0)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
@@ -92,7 +82,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
bias: bool = False,
):
self.output_sizes = output_sizes
super().__init__(input_size, sum(output_sizes), bias=bias)
super().__init__(input_size, sum(output_sizes), bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
param_data = param.data
@@ -113,15 +103,13 @@ class QKVParallelLinear(ColumnParallelLinear):
total_num_kv_heads: int | None = None,
bias: bool = False,
):
self.head_size = head_size
self.total_num_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads or total_num_heads
tp_size = dist.get_world_size()
self.num_heads = divide(self.total_num_heads, tp_size)
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
input_size = hidden_size
output_size = (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size
super().__init__(input_size, output_size, bias)
total_num_kv_heads = total_num_kv_heads or total_num_heads
self.head_size = head_size
self.num_heads = divide(total_num_heads, tp_size)
self.num_kv_heads = divide(total_num_kv_heads, tp_size)
output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
super().__init__(hidden_size, output_size, bias)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
param_data = param.data
@@ -148,17 +136,8 @@ class RowParallelLinear(LinearBase):
output_size: int,
bias: bool = False,
):
super().__init__(input_size, output_size, 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
self.weight.weight_loader = self.weight_loader
if bias:
self.bias = nn.Parameter(torch.empty(self.output_size))
self.bias.weight_loader = self.weight_loader
else:
self.register_parameter("bias", None)
tp_size = dist.get_world_size()
super().__init__(divide(input_size, tp_size), output_size, bias, 1)
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
param_data = param.data
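To make the refactor above concrete: weight and bias creation now lives in LinearBase, and each subclass passes already-divided sizes plus the dimension it shards (tp_dim 0 for column-parallel, 1 for row-parallel). A toy illustration of the resulting per-rank shard shapes, using made-up sizes and no torch.distributed setup:

import torch

def divide(numerator: int, denominator: int) -> int:
    # Sizes must split evenly across tensor-parallel ranks.
    assert numerator % denominator == 0
    return numerator // denominator

tp_size = 4
hidden, intermediate = 1024, 4096
heads, kv_heads, head_size = 16, 4, 64

# ColumnParallelLinear: output dimension divided, so the weight shard is
# (output_size / tp_size, input_size) and is sharded along dim 0 (tp_dim=0).
col_weight = torch.empty(divide(intermediate, tp_size), hidden)
# RowParallelLinear: input dimension divided, full output kept (tp_dim=1).
row_weight = torch.empty(hidden, divide(intermediate, tp_size))
# QKVParallelLinear is column-parallel with an output that packs Q, K and V.
qkv_out = (heads + 2 * kv_heads) * head_size
qkv_weight = torch.empty(divide(qkv_out, tp_size), hidden)
print(col_weight.shape, row_weight.shape, qkv_weight.shape)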

View File

@@ -8,9 +8,7 @@ def apply_rotary_emb(
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
cos = cos.unsqueeze(-2)
sin = sin.unsqueeze(-2)
x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
x1, x2 = torch.chunk(x.float(), 2, dim=-1)
y1 = x1 * cos - x2 * sin
y2 = x2 * cos + x1 * sin
return torch.cat((y1, y2), dim=-1).to(x.dtype)
@@ -33,7 +31,7 @@ class RotaryEmbedding(nn.Module):
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
self.register_buffer("cos_sin_cache", cache, persistent=False)
@torch.compile
@@ -43,15 +41,10 @@ class RotaryEmbedding(nn.Module):
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
num_tokens = positions.size(0)
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query = apply_rotary_emb(query, cos, sin).view(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key = apply_rotary_emb(key, cos, sin).view(key_shape)
query = apply_rotary_emb(query, cos, sin)
key = apply_rotary_emb(key, cos, sin)
return query, key
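For context, a small self-contained sketch of the rotate-half formula used above. With the cached cos/sin table carrying a singleton head axis (the unsqueeze_(1)), the values indexed by positions broadcast over the head dimension and query/key need no per-call reshapes; the shapes and random stand-in values below are illustrative only:

import torch

def apply_rotary_emb_reference(x, cos, sin):
    # x: [num_tokens, num_heads, head_dim]; cos/sin: [num_tokens, 1, head_dim // 2].
    x1, x2 = torch.chunk(x.float(), 2, dim=-1)
    y1 = x1 * cos - x2 * sin
    y2 = x2 * cos + x1 * sin
    return torch.cat((y1, y2), dim=-1).to(x.dtype)

x = torch.randn(5, 8, 64, dtype=torch.float16)  # 5 tokens, 8 heads, head_dim 64
cos = torch.rand(5, 1, 32)                      # stand-ins for cos_sin_cache[positions].chunk(2, -1)
sin = torch.rand(5, 1, 32)
out = apply_rotary_emb_reference(x, cos, sin)
assert out.shape == x.shape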

View File

@@ -8,11 +8,9 @@ class Sampler(nn.Module):
super().__init__()
def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
logits = logits.to(torch.float)
logits = logits.float()
greedy_tokens = logits.argmax(dim=-1)
logits.div_(temperatures.unsqueeze(dim=1))
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
# logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
epsilon = 1e-10
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + epsilon).argmax(dim=-1)
sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + 1e-10).argmax(dim=-1)
return torch.where(temperatures == 0, greedy_tokens, sample_tokens)
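The sampler keeps the exponential-noise form of the Gumbel-max trick: dividing the temperature-scaled probabilities by Exponential(1) noise and taking the argmax draws a token from the categorical distribution, while temperature 0 falls back to greedy argmax. A small demo with made-up logits:

import torch

torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.1]])       # one sequence, vocab of 3 (toy values)
temperatures = torch.tensor([0.8])
greedy = logits.argmax(dim=-1)
probs = torch.softmax(logits / temperatures.unsqueeze(1), dim=-1)
noise = torch.empty_like(probs).exponential_(1)
sampled = probs.div(noise + 1e-10).argmax(dim=-1)
tokens = torch.where(temperatures == 0, greedy, sampled)
print(tokens)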