support qwen2

GeeeekExplorer
2025-11-04 01:44:09 +08:00
parent db1b49dce4
commit 2f21442653
3 changed files with 15 additions and 9 deletions

README.md

@@ -22,9 +22,9 @@ A lightweight vLLM implementation built from scratch.
 pip install git+https://github.com/GeeeekExplorer/nano-vllm.git
 ```
-## Manual Download
-If you prefer to download the model weights manually, use the following command:
+## Model Download
+To download the model weights manually, use the following command:
 ```bash
 huggingface-cli download --resume-download Qwen/Qwen3-0.6B \
   --local-dir ~/huggingface/Qwen3-0.6B/ \
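Note: the CLI command above also has a programmatic equivalent in `huggingface_hub` (a sketch, not part of this commit; recent hub versions resume interrupted downloads by default):

```python
from pathlib import Path
from huggingface_hub import snapshot_download

# Download Qwen/Qwen3-0.6B into the same directory the README's CLI command targets.
snapshot_download(
    repo_id="Qwen/Qwen3-0.6B",
    local_dir=Path.home() / "huggingface" / "Qwen3-0.6B",
)
```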

nanovllm/engine/model_runner.py

@@ -105,10 +105,11 @@ class ModelRunner:
         peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
         current = torch.cuda.memory_stats()["allocated_bytes.all.current"]
         num_kv_heads = hf_config.num_key_value_heads // self.world_size
-        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize
+        head_dim = getattr(hf_config, "head_dim", hf_config.hidden_size // hf_config.num_attention_heads)
+        block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * head_dim * hf_config.torch_dtype.itemsize
         config.num_kvcache_blocks = int(total * config.gpu_memory_utilization - used - peak + current) // block_bytes
         assert config.num_kvcache_blocks > 0
-        self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim)
+        self.kv_cache = torch.empty(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, head_dim)
         layer_id = 0
         for module in self.model.modules():
             if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
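Note: Qwen2-style configs do not expose `head_dim`, which is why this hunk derives it from `hidden_size // num_attention_heads`. A sketch of the resulting arithmetic, using illustrative Qwen2-7B-like numbers (single GPU so no tensor-parallel division, 256-token blocks, bf16; none of these values are read from the repo):

```python
from types import SimpleNamespace

# Illustrative Qwen2-7B-like config (assumption); note the missing head_dim field.
cfg = SimpleNamespace(hidden_size=3584, num_attention_heads=28,
                      num_key_value_heads=4, num_hidden_layers=28)

# Same fallback as the commit: use head_dim if present, else derive it.
head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads)  # 128

# Bytes per KV-cache block: 2 (K and V) x layers x tokens per block x kv heads x head_dim x dtype size.
block_size, dtype_bytes = 256, 2  # illustrative block size, bf16
block_bytes = 2 * cfg.num_hidden_layers * block_size * cfg.num_key_value_heads * head_dim * dtype_bytes
print(head_dim, block_bytes)  # 128 14680064 (~14 MiB per block)
```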

nanovllm/models/qwen3.py

@@ -37,6 +37,7 @@ class Qwen3Attention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim ** -0.5
+        self.qkv_bias = qkv_bias
         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -63,6 +64,7 @@
             self.scaling,
             self.num_kv_heads,
         )
-        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
-        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        if not self.qkv_bias:
+            self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+            self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
@@ -73,9 +75,12 @@
     ) -> torch.Tensor:
         qkv = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q = self.q_norm(q.view(-1, self.num_heads, self.head_dim))
-        k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim))
+        q = q.view(-1, self.num_heads, self.head_dim)
+        k = k.view(-1, self.num_kv_heads, self.head_dim)
         v = v.view(-1, self.num_kv_heads, self.head_dim)
+        if not self.qkv_bias:
+            q = self.q_norm(q)
+            k = self.k_norm(k)
         q, k = self.rotary_emb(positions, q, k)
         o = self.attn(q, k, v)
         output = self.o_proj(o.flatten(1, -1))
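Note: the reordering above keeps the `view`s unconditional and makes only the normalization conditional. A shape-only sketch of the fused-QKV split, with illustrative GQA head counts not taken from any real config:

```python
import torch

num_heads, num_kv_heads, head_dim = 16, 8, 128  # illustrative GQA shapes
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim

qkv = torch.randn(4, q_size + 2 * kv_size)      # fused QKV projection output for 4 tokens
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
q = q.view(-1, num_heads, head_dim)             # (4, 16, 128)
k = k.view(-1, num_kv_heads, head_dim)          # (4, 8, 128)
v = v.view(-1, num_kv_heads, head_dim)          # (4, 8, 128)
```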
@@ -124,7 +129,7 @@ class Qwen3DecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             max_position=config.max_position_embeddings,
             rms_norm_eps=config.rms_norm_eps,
-            qkv_bias=getattr(config, 'attention_bias', False),
+            qkv_bias=getattr(config, 'attention_bias', True),
             head_dim=getattr(config, 'head_dim', None),
             rope_theta=getattr(config, "rope_theta", 1000000),
             rope_scaling=getattr(config, "rope_scaling", None),
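Note: flipping the `attention_bias` fallback from False to True is what actually enables Qwen2 here. Qwen3 configs set `attention_bias=False` explicitly, while Qwen2 configs (to my knowledge) omit the field entirely and hardcode the bias, so a missing field must resolve to True. A sketch with mock configs:

```python
from types import SimpleNamespace

qwen3_cfg = SimpleNamespace(attention_bias=False)  # Qwen3 sets the field explicitly
qwen2_cfg = SimpleNamespace()                      # Qwen2 configs omit it (assumption)

print(getattr(qwen3_cfg, "attention_bias", True))  # False -> no QKV bias, QK-norm on
print(getattr(qwen2_cfg, "attention_bias", True))  # True  -> QKV bias on, QK-norm off
```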