init commit
nanovllm/layers/activation.py (Executable file, 14 lines added)
@@ -0,0 +1,14 @@
import torch
from torch import nn
import torch.nn.functional as F


class SiluAndMul(nn.Module):

    def __init__(self):
        super().__init__()

    @torch.compile
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x, y = x.chunk(2, -1)
        return F.silu(x) * y
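The forward above is the standard SwiGLU gating: the input packs the gate and up projections along the last dimension and is split in half. A standalone restatement of that split, as a sketch with made-up sizes (not code from this commit):

import torch
import torch.nn.functional as F

num_tokens, intermediate_size = 4, 8
gate_up = torch.randn(num_tokens, 2 * intermediate_size)  # merged gate/up projection output (illustrative)

gate, up = gate_up.chunk(2, dim=-1)  # same split as SiluAndMul.forward
out = F.silu(gate) * up
print(out.shape)  # torch.Size([4, 8])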
nanovllm/layers/attention.py (Normal file, 86 lines added)
@@ -0,0 +1,86 @@
import torch
from torch import nn
import triton
import triton.language as tl

from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# from nanovllm.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context


@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    idx = tl.program_id(0)
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    slot = tl.load(slot_mapping_ptr + idx)
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)


def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)


class Attention(nn.Module):

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        o: torch.Tensor
        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_kv_heads, self.head_dim)
        v = v.view(-1, self.num_kv_heads, self.head_dim)
        context = get_context()
        k_cache = self.k_cache
        v_cache = self.v_cache
        if context.is_prefill:
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
            if context.block_tables is None:    # normal prefill
                cu_seqlens_k = context.cu_seqlens_k
                seqused_k = None
            else:    # prefix cache
                cu_seqlens_k = None
                seqused_k = context.context_lens
                k, v = k_cache, v_cache
            o = flash_attn_varlen_func(q, k, v,
                                       max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                       max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=cu_seqlens_k,
                                       seqused_k=seqused_k, softmax_scale=self.scale,
                                       causal=True, block_table=context.block_tables)
        else:    # decode
            o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache, k.unsqueeze(1), v.unsqueeze(1),
                                        cache_seqlens=context.context_lens, block_table=context.block_tables,
                                        softmax_scale=self.scale, causal=True)
        o = o.view(-1, self.num_heads * self.head_dim)
        return o
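How the slot_mapping consumed by store_kvcache gets built is not shown in this file. The following is a minimal sketch of the usual paged-cache indexing, slot = block_id * block_size + offset_in_block; build_slot_mapping is a hypothetical helper and the block-table contents and block_size are illustrative, not values from nanovllm:

import torch

block_size = 4  # assumed cache geometry for illustration

def build_slot_mapping(block_table: list[int], start: int, end: int) -> torch.Tensor:
    """Hypothetical helper: map token positions [start, end) of one sequence to flat cache slots."""
    slots = []
    for pos in range(start, end):
        block_id = block_table[pos // block_size]
        slots.append(block_id * block_size + pos % block_size)
    return torch.tensor(slots, dtype=torch.int64)

# A sequence whose first six tokens live in physical blocks 7 and 2.
print(build_slot_mapping([7, 2], 0, 6))  # tensor([28, 29, 30, 31,  8,  9])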
nanovllm/layers/embed_head.py (Normal file, 72 lines added)
@@ -0,0 +1,72 @@
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist

from nanovllm.utils.context import get_context


class VocabParallelEmbedding(nn.Module):

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
    ):
        super().__init__()
        self.tp_rank = 0  # get_tensor_model_parallel_rank()
        self.tp_size = 1  # get_tensor_model_parallel_world_size()
        assert num_embeddings % self.tp_size == 0
        self.num_embeddings = num_embeddings
        self.num_embeddings_per_partition = self.num_embeddings // self.tp_size
        self.vocab_start_idx = self.num_embeddings_per_partition * self.tp_rank
        self.vocab_end_idx = self.vocab_start_idx + self.num_embeddings_per_partition
        self.embedding_dim = embedding_dim
        self.weight = nn.Parameter(torch.empty(self.num_embeddings_per_partition, embedding_dim))
        self.weight.weight_loader = self.weight_loader

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(0)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor):
        if self.tp_size > 1:
            # Zero out token ids that fall outside this rank's vocab shard.
            mask = (x >= self.vocab_start_idx) & (x < self.vocab_end_idx)
            x = mask * (x - self.vocab_start_idx)
        y = F.embedding(x, self.weight)
        if self.tp_size > 1:
            # unsqueeze so the token mask broadcasts over embedding_dim
            y = mask.unsqueeze(1) * y
            dist.all_reduce(y)
        return y


class ParallelLMHead(VocabParallelEmbedding):

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        bias: bool = False,
    ):
        super().__init__(num_embeddings, embedding_dim)
        if bias:
            self.bias = nn.Parameter(torch.empty(self.num_embeddings_per_partition))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def forward(self, x: torch.Tensor):
        context = get_context()
        if context.is_prefill:
            last_indices = context.cu_seqlens_q[1:] - 1
            x = x[last_indices].contiguous()
        logits = F.linear(x, self.weight, self.bias)
        # if self.tp_size > 1:
        #     all_logits = [torch.empty_like(logits) for _ in range(self.tp_size)]
        #     dist.gather(logits, all_logits, 0)
        #     logits = torch.cat(all_logits, -1)
        return logits if self.tp_rank == 0 else None
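During prefill, ParallelLMHead only computes logits for the last token of each packed sequence, selected via cu_seqlens_q. A standalone sketch of that indexing with made-up sequence lengths:

import torch

# Three packed sequences of lengths 3, 2, 4 (illustrative numbers).
seq_lens = torch.tensor([3, 2, 4])
cu_seqlens_q = torch.cat([torch.tensor([0]), seq_lens.cumsum(0)])  # tensor([0, 3, 5, 9])

hidden = torch.arange(9).unsqueeze(1).float()  # 9 packed tokens, hidden_size = 1
last_indices = cu_seqlens_q[1:] - 1            # tensor([2, 4, 8])
print(hidden[last_indices].squeeze(1))         # tensor([2., 4., 8.])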
nanovllm/layers/layernorm.py (Executable file, 51 lines added)
@@ -0,0 +1,51 @@
import torch
from torch import nn


class RMSNorm(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    @torch.compile
    def rms_forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        orig_dtype = x.dtype
        x = x.to(torch.float32)
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + self.eps))
        x = x.to(orig_dtype).mul_(self.weight)
        return x

    @torch.compile
    def add_rms_forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        orig_dtype = x.dtype
        x = x.to(torch.float32).add_(residual.to(torch.float32))
        residual = x.to(orig_dtype)
        var = x.pow(2).mean(dim=-1, keepdim=True)
        x.mul_(torch.rsqrt(var + self.eps))
        x = x.to(orig_dtype).mul_(self.weight)
        return x, residual

    def forward(
        self,
        x: torch.Tensor,
        residual: torch.Tensor | None = None,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        if residual is None:
            return self.rms_forward(x)
        else:
            return self.add_rms_forward(x, residual)
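The fused add_rms_forward path returns both the normalized activation and the updated residual, doing the add and the variance in float32 before casting back. A standalone restatement of that arithmetic on plain tensors (a sketch; the hidden_size and bfloat16 dtype are illustrative):

import torch

hidden_size, eps = 16, 1e-6
weight = torch.ones(hidden_size)

x = torch.randn(4, hidden_size, dtype=torch.bfloat16)         # block output
residual = torch.randn(4, hidden_size, dtype=torch.bfloat16)  # running residual stream

s = x.to(torch.float32) + residual.to(torch.float32)  # add in float32 for accuracy
new_residual = s.to(x.dtype)                           # handed back for the next layer
var = s.pow(2).mean(dim=-1, keepdim=True)
normed = (s * torch.rsqrt(var + eps)).to(x.dtype) * weight

print(normed.dtype, new_residual.dtype)  # torch.bfloat16 torch.bfloat16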
nanovllm/layers/linear.py (Executable file, 199 lines added)
@@ -0,0 +1,199 @@
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist


def divide(numerator, denominator):
    assert numerator % denominator == 0
    return numerator // denominator


class LinearBase(nn.Module):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        tp_dim: int | None = None,
    ):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.tp_dim = tp_dim
        self.tp_rank = 0  # get_tensor_model_parallel_rank()
        self.tp_size = 1  # get_tensor_model_parallel_world_size()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class ReplicatedLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size)
        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        assert param.size() == loaded_weight.size()
        param.data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight, self.bias)


class ColumnParallelLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size, 0)
        self.input_size_per_partition = input_size
        self.output_size_per_partition = divide(output_size, self.tp_size)
        self.output_partition_sizes = [self.output_size_per_partition]
        # If QKV or MergedColumn, use output size of each partition.
        if hasattr(self, "output_sizes"):
            self.output_partition_sizes = [
                divide(output_size, self.tp_size)
                for output_size in self.output_sizes
            ]

        self.weight = nn.Parameter(torch.empty(self.output_size_per_partition, self.input_size))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size_per_partition))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(self.tp_dim)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.weight, self.bias)


class MergedColumnParallelLinear(ColumnParallelLinear):

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = False,
    ):
        self.output_sizes = output_sizes
        tp_size = 1  # get_tensor_model_parallel_world_size()
        assert all(output_size % tp_size == 0 for output_size in output_sizes)
        super().__init__(input_size, sum(output_sizes), bias=bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: int | None = None):
        param_data = param.data
        shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
        shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
        # loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)


class QKVParallelLinear(ColumnParallelLinear):

    def __init__(
        self,
        hidden_size: int,
        head_size: int,
        total_num_heads: int,
        total_num_kv_heads: int | None = None,
        bias: bool = False,
    ):
        self.hidden_size = hidden_size
        self.head_size = head_size
        self.total_num_heads = total_num_heads
        if total_num_kv_heads is None:
            total_num_kv_heads = total_num_heads
        self.total_num_kv_heads = total_num_kv_heads
        # Divide the weight matrix along the last dimension.
        tp_size = 1  # get_tensor_model_parallel_world_size()
        self.num_heads = divide(self.total_num_heads, tp_size)
        self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
        input_size = self.hidden_size
        output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
        self.output_sizes = [
            self.num_heads * self.head_size * tp_size,     # q_proj
            self.num_kv_heads * self.head_size * tp_size,  # k_proj
            self.num_kv_heads * self.head_size * tp_size,  # v_proj
        ]

        super().__init__(input_size, output_size, bias)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, loaded_shard_id: str | None = None):
        param_data = param.data
        assert loaded_shard_id in ["q", "k", "v"]
        if loaded_shard_id == "q":
            shard_size = self.num_heads * self.head_size
            shard_offset = 0
        elif loaded_shard_id == "k":
            shard_size = self.num_kv_heads * self.head_size
            shard_offset = self.num_heads * self.head_size
        else:
            shard_size = self.num_kv_heads * self.head_size
            shard_offset = self.num_heads * self.head_size + self.num_kv_heads * self.head_size
        param_data = param_data.narrow(self.tp_dim, shard_offset, shard_size)
        # loaded_weight = loaded_weight.narrow(self.tp_dim, self.tp_rank * shard_size, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)


class RowParallelLinear(LinearBase):

    def __init__(
        self,
        input_size: int,
        output_size: int,
        bias: bool = False,
    ):
        super().__init__(input_size, output_size, 1)
        self.input_size_per_partition = divide(input_size, self.tp_size)
        self.output_size_per_partition = output_size
        self.output_partition_sizes = [output_size]

        self.weight = nn.Parameter(torch.empty(self.output_size, self.input_size_per_partition))
        self.weight.weight_loader = self.weight_loader
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_size))
            self.bias.weight_loader = self.weight_loader
        else:
            self.register_parameter("bias", None)

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        shard_size = param_data.size(self.tp_dim)
        start_idx = self.tp_rank * shard_size
        loaded_weight = loaded_weight.narrow(self.tp_dim, start_idx, shard_size)
        assert param_data.size() == loaded_weight.size()
        param_data.copy_(loaded_weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = F.linear(x, self.weight, self.bias if self.tp_rank == 0 else None)
        if self.tp_size > 1:
            dist.all_reduce(y)
        return y
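Each parameter above carries a weight_loader attribute, and the merged variants additionally take a shard id so separate q/k/v (or gate/up) checkpoint tensors land in the right slice of the fused weight. A minimal sketch of that calling convention (it assumes nanovllm is importable; the sizes and the "shards" dict are made up and stand in for whatever checkpoint loader the engine actually uses):

import torch
from nanovllm.layers.linear import QKVParallelLinear

hidden_size, head_size = 32, 8
qkv = QKVParallelLinear(hidden_size, head_size, total_num_heads=4, total_num_kv_heads=2)

# Pretend checkpoint tensors, one per projection (illustrative shapes).
shards = {
    "q": torch.randn(4 * head_size, hidden_size),
    "k": torch.randn(2 * head_size, hidden_size),
    "v": torch.randn(2 * head_size, hidden_size),
}
for shard_id, tensor in shards.items():
    # param.weight_loader(param, loaded_weight, shard_id) copies into the fused weight.
    qkv.weight.weight_loader(qkv.weight, tensor, shard_id)

print(qkv.weight.shape)  # torch.Size([64, 32]), i.e. (4 + 2 + 2) * 8 rows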
nanovllm/layers/rotary_embedding.py (Normal file, 73 lines added)
@@ -0,0 +1,73 @@
from functools import lru_cache

import torch
from torch import nn


def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    cos = cos.unsqueeze(-2)
    sin = sin.unsqueeze(-2)
    x1, x2 = torch.chunk(x.to(torch.float32), 2, dim=-1)
    y1 = x1 * cos - x2 * sin
    y2 = x2 * cos + x1 * sin
    return torch.cat((y1, y2), dim=-1).to(x.dtype)


class RotaryEmbedding(nn.Module):

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
    ) -> None:
        super().__init__()
        self.head_size = head_size
        self.rotary_dim = rotary_dim
        assert rotary_dim == head_size
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
        t = torch.arange(self.max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1)
        self.register_buffer("cos_sin_cache", cache, persistent=False)

    @torch.compile
    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        positions = positions.flatten()
        num_tokens = positions.shape[0]
        cos_sin = self.cos_sin_cache[positions]
        cos, sin = cos_sin.chunk(2, dim=-1)
        query_shape = query.shape
        query = query.view(num_tokens, -1, self.head_size)
        query = apply_rotary_emb(query, cos, sin).view(query_shape)
        key_shape = key.shape
        key = key.view(num_tokens, -1, self.head_size)
        key = apply_rotary_emb(key, cos, sin).view(key_shape)
        return query, key


@lru_cache(1)
def get_rope(
    head_size: int,
    rotary_dim: int,
    max_position: int,
    base: float,
    rope_scaling: dict | None = None,
):
    assert rope_scaling is None
    rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base)
    return rotary_emb
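A short usage sketch of get_rope on packed query/key tensors (it assumes nanovllm is importable and that torch.compile can run in the environment; all sizes here are illustrative):

import torch
from nanovllm.layers.rotary_embedding import get_rope

head_size, num_heads, num_kv_heads = 8, 4, 2
rope = get_rope(head_size, head_size, max_position=128, base=10000.0)

num_tokens = 5
positions = torch.arange(num_tokens)                   # one position per packed token
q = torch.randn(num_tokens, num_heads * head_size)
k = torch.randn(num_tokens, num_kv_heads * head_size)

q, k = rope(positions, q, k)  # applies the rotation per head, preserving shapes
print(q.shape, k.shape)       # torch.Size([5, 32]) torch.Size([5, 16])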
nanovllm/layers/sampler.py (Normal file, 17 lines added)
@@ -0,0 +1,17 @@
import torch
from torch import nn


class Sampler(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, temperatures: torch.Tensor | None = None):
        logits = logits.to(torch.float)
        if temperatures is not None:
            logits.div_(temperatures.unsqueeze(dim=1))
        probs = torch.softmax(logits, dim=-1, dtype=torch.float)
        # logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
        sampled_tokens = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
        return sampled_tokens
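The sampling line divides the probabilities by i.i.d. Exponential(1) noise and takes the argmax, which draws from the categorical distribution given by probs (the exponential-race form of the Gumbel-max trick). A standalone sketch that checks the empirical frequencies (counts are approximate by nature):

import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.6, 0.3])

draws = []
for _ in range(10000):
    noise = torch.empty_like(probs).exponential_(1)
    draws.append(torch.argmax(probs / noise).item())

counts = torch.bincount(torch.tensor(draws), minlength=3).float() / len(draws)
print(counts)  # roughly tensor([0.10, 0.60, 0.30])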