simplify
@@ -18,7 +18,6 @@ def main():
             [{"role": "user", "content": prompt}],
             tokenize=False,
             add_generation_prompt=True,
-            enable_thinking=True
         )
         for prompt in prompts
     ]
@@ -26,7 +26,6 @@ class Block:
 class BlockManager:

     def __init__(self, num_blocks: int, block_size: int):
-        assert num_blocks > 0
         self.block_size = block_size
         self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
         self.hash_to_block_id: dict[int, int] = dict()
@@ -86,7 +86,7 @@ class LLMEngine:
                 outputs[seq_id] = token_ids
                 if use_tqdm:
                     pbar.update(1)
-        outputs = [outputs[seq_id] for seq_id in sorted(outputs)]
+        outputs = [outputs[seq_id] for seq_id in sorted(outputs.keys())]
         outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs]
         if use_tqdm:
             pbar.close()
@@ -66,7 +66,7 @@ class ModelRunner:
                 break

     def read_shm(self):
-        assert self.world_size > 1 and self.rank
+        assert self.world_size > 1 and self.rank > 0
         self.event.wait()
         n = int.from_bytes(self.shm.buf[0:4], "little")
         method_name, *args = pickle.loads(self.shm.buf[4:n+4])
@@ -74,7 +74,7 @@ class ModelRunner:
         return method_name, args

     def write_shm(self, method_name, *args):
-        assert self.world_size > 1 and not self.rank
+        assert self.world_size > 1 and self.rank == 0
         data = pickle.dumps([method_name, *args])
         n = len(data)
         self.shm.buf[0:4] = n.to_bytes(4, "little")
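For context on what these two asserts guard: the driver rank serializes each call as a pickled [method_name, *args] list prefixed by a 4-byte little-endian length and places it in a shared-memory buffer, and worker ranks decode it after being signalled. Below is a minimal standalone sketch of just that framing; the names (write_rpc, read_rpc, shm) are illustrative, and the Event-based handshake the real read_shm/write_shm use is omitted.

    import pickle
    from multiprocessing import shared_memory

    # Hedged sketch of length-prefixed pickle framing over a shared-memory buffer.
    shm = shared_memory.SharedMemory(create=True, size=2**20)

    def write_rpc(method_name, *args):
        data = pickle.dumps([method_name, *args])
        shm.buf[0:4] = len(data).to_bytes(4, "little")   # 4-byte length header
        shm.buf[4:4 + len(data)] = data                  # pickled payload

    def read_rpc():
        n = int.from_bytes(shm.buf[0:4], "little")
        method_name, *args = pickle.loads(shm.buf[4:n + 4])
        return method_name, args

    write_rpc("run", 123, True)
    print(read_rpc())    # ('run', [123, True])
    shm.close(); shm.unlink()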
@@ -108,7 +108,7 @@ class ModelRunner:
         block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
         config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
         assert config.num_kvcache_blocks > 0
-        self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
+        self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
         layer_id = 0
         for module in self.model.modules():
             if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
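To make the sizing arithmetic above concrete, here is a worked example with made-up numbers (28 layers, 8 KV heads, head_dim 128, bf16, block_size 256, a 40 GB card at 90% utilization); the real code subtracts the allocator's used/peak/current statistics, which this sketch folds into a single budget.

    # Hypothetical numbers, for illustration only.
    num_hidden_layers, num_kv_heads, head_dim = 28, 8, 128
    block_size, itemsize = 256, 2                 # 256 tokens per block, bf16

    # 2x for the K and V planes of every layer.
    block_bytes = 2 * num_hidden_layers * block_size * num_kv_heads * head_dim * itemsize
    print(block_bytes)                            # 29_360_128 bytes, i.e. 28 MiB per block

    budget = int(40e9 * 0.9)                      # pretend free budget after model weights
    print(budget // block_bytes)                  # 1226 KV-cache blocks of 256 tokens each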
@@ -141,7 +141,7 @@ class ModelRunner:
             cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k)
             max_seqlen_q = max(seqlen_q, max_seqlen_q)
             max_seqlen_k = max(seqlen_k, max_seqlen_k)
-            if not seq.block_table:
+            if not seq.block_table:    # warmup
                 continue
             for i in range(seq.num_cached_blocks, seq.num_blocks):
                 start = seq.block_table[i] * self.block_size
@@ -194,12 +194,11 @@ class ModelRunner:
             context = get_context()
             graph = self.graphs[next(x for x in self.graph_bs if x >= bs)]
             graph_vars = self.graph_vars
-            for k, v in graph_vars.items():
-                if k != "outputs":
-                    v.zero_()
             graph_vars["input_ids"][:bs] = input_ids
             graph_vars["positions"][:bs] = positions
+            graph_vars["slot_mapping"].fill_(-1)
             graph_vars["slot_mapping"][:bs] = context.slot_mapping
+            graph_vars["context_lens"].zero_()
             graph_vars["context_lens"][:bs] = context.context_lens
             graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables
             graph.replay()
@@ -19,11 +19,12 @@ def store_kvcache_kernel(
     D: tl.constexpr,
 ):
     idx = tl.program_id(0)
+    slot = tl.load(slot_mapping_ptr + idx)
+    if slot == -1: return
     key_offsets = idx * key_stride + tl.arange(0, D)
     value_offsets = idx * value_stride + tl.arange(0, D)
     key = tl.load(key_ptr + key_offsets)
     value = tl.load(value_ptr + value_offsets)
-    slot = tl.load(slot_mapping_ptr + idx)
     cache_offsets = slot * D + tl.arange(0, D)
     tl.store(k_cache_ptr + cache_offsets, key)
     tl.store(v_cache_ptr + cache_offsets, value)
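The two added lines hoist the slot lookup to the top so padded entries (slot_mapping filled with -1 during CUDA-graph replay, per the ModelRunner hunk above) bail out before any loads or stores. As a rough pure-PyTorch rendering of what the kernel does per token, assuming contiguous key/value tensors and caches holding D = num_kv_heads * head_dim elements per slot (function and variable names here are illustrative):

    import torch

    def store_kvcache_reference(key, value, k_cache, v_cache, slot_mapping):
        # key/value: [num_tokens, num_kv_heads, head_dim]
        num_tokens, num_kv_heads, head_dim = key.shape
        D = num_kv_heads * head_dim
        k_flat, v_flat = k_cache.view(-1, D), v_cache.view(-1, D)
        for idx in range(num_tokens):
            slot = int(slot_mapping[idx])
            if slot == -1:        # padded token from graph replay: nothing to cache
                continue
            k_flat[slot] = key[idx].reshape(D)
            v_flat[slot] = value[idx].reshape(D)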
@@ -56,10 +57,6 @@ class Attention(nn.Module):
         self.k_cache = self.v_cache = torch.tensor([])

     def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
-        o: torch.Tensor
-        q = q.view(-1, self.num_heads, self.head_dim)
-        k = k.view(-1, self.num_kv_heads, self.head_dim)
-        v = v.view(-1, self.num_kv_heads, self.head_dim)
         context = get_context()
         k_cache, v_cache = self.k_cache, self.v_cache
         if k_cache.numel() and v_cache.numel():
@@ -75,5 +72,4 @@ class Attention(nn.Module):
             o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                         cache_seqlens=context.context_lens, block_table=context.block_tables,
                                         softmax_scale=self.scale, causal=True)
-        o = o.view(-1, self.num_heads * self.head_dim)
         return o
@@ -29,7 +29,6 @@ class VocabParallelEmbedding(nn.Module):
         shard_size = param_data.size(0)
         start_idx = self.tp_rank * shard_size
         loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
-        assert param_data.size() == loaded_weight.size()
         param_data.copy_(loaded_weight)

     def forward(self, x: torch.Tensor):
@@ -51,19 +50,15 @@ class ParallelLMHead(VocabParallelEmbedding):
         embedding_dim: int,
         bias: bool = False,
     ):
+        assert not bias
         super().__init__(num_embeddings, embedding_dim)
-        if bias:
-            self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
-            self.bias.weight_loader = self.weight_loader
-        else:
-            self.register_parameter("bias", None)

     def forward(self, x: torch.Tensor):
         context = get_context()
         if context.is_prefill:
             last_indices = context.cu_seqlens_q[1:] - 1
             x = x[last_indices].contiguous()
-        logits = F.linear(x, self.weight, self.bias)
+        logits = F.linear(x, self.weight)
         if self.tp_size > 1:
             all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)] if self.tp_rank == 0 else None
             dist.gather(logits, all_logits, 0)
@@ -10,7 +10,6 @@ class RMSNorm(nn.Module):
         eps: float = 1e-6,
     ) -> None:
         super().__init__()
-        self.hidden_size = hidden_size
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(hidden_size))

@@ -20,7 +19,7 @@ class RMSNorm(nn.Module):
         x: torch.Tensor,
     ) -> torch.Tensor:
         orig_dtype = x.dtype
-        x = x.to(torch.float32)
+        x = x.float()
         var = x.pow(2).mean(dim=-1, keepdim=True)
         x.mul_(torch.rsqrt(var + self.eps))
         x = x.to(orig_dtype).mul_(self.weight)
@@ -33,7 +32,7 @@ class RMSNorm(nn.Module):
         residual: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         orig_dtype = x.dtype
-        x = x.to(torch.float32).add_(residual.to(torch.float32))
+        x = x.float().add_(residual.float())
         residual = x.to(orig_dtype)
         var = x.pow(2).mean(dim=-1, keepdim=True)
         x.mul_(torch.rsqrt(var + self.eps))
@@ -15,14 +15,20 @@ class LinearBase(nn.Module):
         self,
         input_size: int,
         output_size: int,
+        bias: bool = False,
         tp_dim: int | None = None,
     ):
         super().__init__()
-        self.input_size = input_size
-        self.output_size = output_size
         self.tp_dim = tp_dim
         self.tp_rank = dist.get_rank()
         self.tp_size = dist.get_world_size()
+        self.weight = nn.Parameter(torch.empty(output_size, input_size))
+        self.weight.weight_loader = self.weight_loader
+        if bias:
+            self.bias = nn.Parameter(torch.empty(output_size))
+            self.bias.weight_loader = self.weight_loader
+        else:
+            self.register_parameter("bias", None)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         raise NotImplementedError
@@ -36,14 +42,7 @@ class ReplicatedLinear(LinearBase):
         output_size: int,
         bias: bool = False,
     ):
-        super().__init__(input_size, output_size)
-        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
-        self.weight.weight_loader = self.weight_loader
-        if bias:
-            self.bias = nn.Parameter(torch.empty(self.output_size))
-            self.bias.weight_loader = self.weight_loader
-        else:
-            self.register_parameter("bias", None)
+        super().__init__(input_size, output_size, bias)

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
         param.data.copy_(loaded_weight)
@@ -60,17 +59,8 @@ class ColumnParallelLinear(LinearBase):
         output_size: int,
         bias: bool = False,
     ):
-        super().__init__(input_size, output_size, 0)
-        self.input_size_per_partition = input_size
-        self.output_size_per_partition = divide(output_size, self.tp_size)
-
-        self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
-        self.weight.weight_loader = self.weight_loader
-        if bias:
-            self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
-            self.bias.weight_loader = self.weight_loader
-        else:
-            self.register_parameter("bias", None)
+        tp_size = dist.get_world_size()
+        super().__init__(input_size, divide(output_size, tp_size), bias, 0)

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
         param_data = param.data
@@ -92,7 +82,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         bias: bool = False,
     ):
         self.output_sizes = output_sizes
-        super().__init__(input_size, sum(output_sizes), bias=bias)
+        super().__init__(input_size, sum(output_sizes), bias)

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int):
         param_data = param.data
@@ -113,15 +103,13 @@ class QKVParallelLinear(ColumnParallelLinear):
         total_num_kv_heads: int | None = None,
         bias: bool = False,
     ):
-        self.head_size = head_size
-        self.total_num_heads = total_num_heads
-        self.total_num_kv_heads = total_num_kv_heads or total_num_heads
         tp_size = dist.get_world_size()
-        self.num_heads = divide(self.total_num_heads, tp_size)
-        self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
-        input_size = hidden_size
-        output_size = (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size
-        super().__init__(input_size, output_size, bias)
+        total_num_kv_heads = total_num_kv_heads or total_num_heads
+        self.head_size = head_size
+        self.num_heads = divide(total_num_heads, tp_size)
+        self.num_kv_heads = divide(total_num_kv_heads, tp_size)
+        output_size = (total_num_heads + 2 * total_num_kv_heads) * self.head_size
+        super().__init__(hidden_size, output_size, bias)

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str):
         param_data = param.data
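For a sense of the per-rank shapes this sharding arithmetic produces, a short worked example with hypothetical sizes (hidden 2048, 16 query heads, 8 KV heads, head_size 128, tensor-parallel size 2; these numbers are illustrative, not taken from the diff):

    # Hypothetical sizes, to illustrate the fused-QKV sharding arithmetic only.
    hidden_size, head_size = 2048, 128
    total_num_heads, total_num_kv_heads, tp_size = 16, 8, 2

    num_heads = total_num_heads // tp_size            # 8 query heads per rank
    num_kv_heads = total_num_kv_heads // tp_size      # 4 KV heads per rank
    output_size = (total_num_heads + 2 * total_num_kv_heads) * head_size   # 4096: full QKV width
    per_rank_rows = output_size // tp_size            # 2048 rows of the fused QKV weight per rank
    print(num_heads, num_kv_heads, output_size, per_rank_rows)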
@@ -148,17 +136,8 @@ class RowParallelLinear(LinearBase):
         output_size: int,
         bias: bool = False,
     ):
-        super().__init__(input_size, output_size, 1)
-        self.input_size_per_partition = divide(input_size, self.tp_size)
-        self.output_size_per_partition = output_size
-
-        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
-        self.weight.weight_loader = self.weight_loader
-        if bias:
-            self.bias = nn.Parameter(torch.empty(self.output_size))
-            self.bias.weight_loader = self.weight_loader
-        else:
-            self.register_parameter("bias", None)
+        tp_size = dist.get_world_size()
+        super().__init__(divide(input_size, tp_size), output_size, bias, 1)

     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
         param_data = param.data
@@ -8,9 +8,7 @@ def apply_rotary_emb(
     cos: torch.Tensor,
     sin: torch.Tensor,
 ) -> torch.Tensor:
-    cos = cos.unsqueeze(-2)
-    sin = sin.unsqueeze(-2)
-    x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
+    x1, x2 = torch.chunk(x.float(), 2, dim=-1)
     y1 = x1 * cos - x2 * sin
     y2 = x2 * cos + x1 * sin
     return torch.cat((y1, y2), dim=-1).to(x.dtype)
@@ -33,7 +31,7 @@ class RotaryEmbedding(nn.Module):
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
         sin = freqs.sin()
-        cache = torch.cat((cos, sin), dim=-1)
+        cache = torch.cat((cos, sin), dim=-1).unsqueeze_(1)
         self.register_buffer("cos_sin_cache", cache, persistent=False)

     @torch.compile
@@ -43,15 +41,10 @@ class RotaryEmbedding(nn.Module):
         query: torch.Tensor,
         key: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        num_tokens = positions.size(0)
         cos_sin = self.cos_sin_cache[positions]
         cos, sin = cos_sin.chunk(2, dim=-1)
-        query_shape = query.shape
-        query = query.view(num_tokens, -1, self.head_size)
-        query = apply_rotary_emb(query, cos, sin).view(query_shape)
-        key_shape = key.shape
-        key = key.view(num_tokens, -1, self.head_size)
-        key = apply_rotary_emb(key, cos, sin).view(key_shape)
+        query = apply_rotary_emb(query, cos, sin)
+        key = apply_rotary_emb(key, cos, sin)
         return query, key


@@ -8,11 +8,9 @@ class Sampler(nn.Module):
         super().__init__()

     def forward(self, logits: torch.Tensor, temperatures: torch.Tensor):
-        logits = logits.to(torch.float)
+        logits = logits.float()
         greedy_tokens = logits.argmax(dim=-1)
         logits.div_(temperatures.unsqueeze(dim=1))
         probs = torch.softmax(logits, dim=-1, dtype=torch.float)
-        # logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
-        epsilon = 1e-10
-        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + epsilon).argmax(dim=-1)
+        sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + 1e-10).argmax(dim=-1)
         return torch.where(temperatures == 0, greedy_tokens, sample_tokens)
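The sampling line kept here relies on the exponential-race (Gumbel-max) trick: with E_i ~ Exp(1), argmax_i p_i / E_i is distributed exactly like a categorical draw with probabilities p_i, so a single argmax replaces a per-row torch.multinomial; the 1e-10 only guards against dividing by a zero draw. A small self-contained sanity check of that equivalence, with made-up probabilities:

    import torch

    torch.manual_seed(0)
    probs = torch.tensor([0.1, 0.2, 0.3, 0.4]).expand(200_000, 4).contiguous()

    # Exponential-race sampling, as in the Sampler's forward pass.
    race = probs / (torch.empty_like(probs).exponential_(1) + 1e-10)
    samples = race.argmax(dim=-1)

    # Empirical frequencies should approach [0.1, 0.2, 0.3, 0.4].
    print(torch.bincount(samples, minlength=4) / samples.numel())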
@@ -73,15 +73,12 @@ class Qwen3Attention(nn.Module):
     ) -> torch.Tensor:
         qkv = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q_by_head = q.view(-1, self.num_heads, self.head_dim)
-        q_by_head = self.q_norm(q_by_head)
-        q = q_by_head.view(q.shape)
-        k_by_head = k.view(-1, self.num_kv_heads, self.head_dim)
-        k_by_head = self.k_norm(k_by_head)
-        k = k_by_head.view(k.shape)
+        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
+        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
+        v = v.view(-1, self.num_kv_heads, self.head_dim)
         q, k = self.rotary_emb(positions, q, k)
         o = self.attn(q, k, v)
-        output = self.o_proj(o)
+        output = self.o_proj(o.flatten(1, -1))
         return output


@@ -147,8 +144,7 @@ class Qwen3DecoderLayer(nn.Module):
         residual: torch.Tensor | None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
+            hidden_states, residual = self.input_layernorm(hidden_states), hidden_states
         else:
             hidden_states, residual = self.input_layernorm(hidden_states, residual)
         hidden_states = self.self_attn(positions, hidden_states)
@@ -205,12 +201,10 @@ class Qwen3ForCausalLM(nn.Module):
         input_ids: torch.Tensor,
         positions: torch.Tensor,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions)
-        return hidden_states
+        return self.model(input_ids, positions)

     def compute_logits(
         self,
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
-        logits = self.lm_head(hidden_states)
-        return logits
+        return self.lm_head(hidden_states)