Compare commits
4 Commits
ff8b09cd35
...
tzj/minfer
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9b52d25866 | ||
|
|
8c3418725b | ||
|
|
b3685c9190 | ||
|
|
6927a75ac3 |
74
CLAUDE.md
74
CLAUDE.md
@@ -20,6 +20,80 @@ For sparse attention related content (block sparse attention, MInference, FlexPr
|
||||
- **BlockManager** (`block_manager.py`): Paged attention with prefix caching (xxhash), default block size 4096
|
||||
- **Attention** (`layers/attention.py`): FlashAttention with chunked methods for CPU offload
|
||||
|
||||
## PyTorch Hooks for Debugging
|
||||
|
||||
### Hook Positions in Qwen3
|
||||
|
||||
```
|
||||
decoder_layer
|
||||
├── input_layernorm (RMSNorm)
|
||||
├── self_attn (Qwen3Attention) ← Hook here for attention I/O after o_proj
|
||||
│ ├── q_proj → q_norm → RoPE
|
||||
│ ├── k_proj → k_norm → RoPE
|
||||
│ ├── v_proj
|
||||
│ ├── attn (Attention) ← Hook here for Q/K/V tensors
|
||||
│ │ └── FlashAttention / SDPA
|
||||
│ └── o_proj
|
||||
├── post_attention_layernorm (RMSNorm)
|
||||
└── mlp (Qwen3MLP)
|
||||
```
|
||||
|
||||
### Hook Types & Data Shapes
|
||||
|
||||
| Hook Position | Type | Captured Data |
|
||||
|---------------|------|---------------|
|
||||
| `self_attn` | post | `[batch, seq_len, hidden_size]` - after o_proj |
|
||||
| `self_attn.attn` | pre | Q,K,V: `[seq_len, num_heads, head_dim]` - after RoPE |
|
||||
| `self_attn.attn` | post | `[seq_len, num_heads, head_dim]` - before o_proj |
|
||||
|
||||
### Example: Capture Attention Outputs
|
||||
|
||||
```python
|
||||
storage = {}
|
||||
|
||||
def make_hook(layer_id: int, storage: dict):
|
||||
def hook(module, inputs, output):
|
||||
if isinstance(output, tuple):
|
||||
attn_output = output[0]
|
||||
else:
|
||||
attn_output = output
|
||||
# nanovllm shape: [num_tokens, hidden_size] -> add batch dim
|
||||
if attn_output.dim() == 2:
|
||||
attn_output = attn_output.unsqueeze(0)
|
||||
storage[layer_id] = attn_output.detach().clone()
|
||||
return hook
|
||||
|
||||
# Register hooks
|
||||
hooks = []
|
||||
for layer_idx, layer in enumerate(model.model.layers):
|
||||
hooks.append(layer.self_attn.register_forward_hook(make_hook(layer_idx, storage)))
|
||||
|
||||
# Run inference...
|
||||
|
||||
# Cleanup
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
```
|
||||
|
||||
### Alignment Testing
|
||||
|
||||
Use `tests/test_align.py` to compare nanovllm with reference torch implementation:
|
||||
|
||||
```bash
|
||||
python tests/test_align.py
|
||||
```
|
||||
|
||||
Key files:
|
||||
- `tests/modeling_qwen3.py`: Reference Qwen3 implementation (torch + transformers only)
|
||||
- `tests/test_align.py`: Compares attention outputs between nanovllm and reference
|
||||
- `tests/test_needle_ref.py`: Reference needle test using custom Qwen3
|
||||
|
||||
### Common Pitfalls
|
||||
|
||||
1. **Shape mismatch**: nanovllm uses `[num_tokens, ...]` while torch uses `[batch, seq_len, ...]`
|
||||
2. **Hook position**: `self_attn` captures after o_proj, `self_attn.attn` captures before o_proj
|
||||
3. **Output format**: nanovllm returns tuple `(attn_output, None)`, handle with `output[0]`
|
||||
|
||||
## CPU Offload System
|
||||
|
||||
### Ring Buffer Design
|
||||
|
||||
757
tests/modeling_qwen3.py
Normal file
757
tests/modeling_qwen3.py
Normal file
@@ -0,0 +1,757 @@
|
||||
"""
|
||||
Custom Qwen3 implementation using only torch and transformers.
|
||||
This file provides a clean reference implementation for understanding the model computation graph.
|
||||
|
||||
Computation Graph:
|
||||
==================
|
||||
|
||||
Input: token_ids [batch, seq_len]
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ Embedding │ embed_tokens: [vocab_size, hidden_size]
|
||||
└─────────────┘
|
||||
│
|
||||
▼
|
||||
hidden_states [batch, seq_len, hidden_size]
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ Decoder Layer (x N) │
|
||||
│ ┌───────────────────────────────────────────────────┐ │
|
||||
│ │ Self Attention Block │ │
|
||||
│ │ │ │
|
||||
│ │ input_layernorm (RMSNorm) │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ ┌─────────────────────────────────────────────┐ │ │
|
||||
│ │ │ Qwen3Attention │ │ │
|
||||
│ │ │ Q = q_proj(x) → q_norm → reshape │ │ │
|
||||
│ │ │ K = k_proj(x) → k_norm → reshape │ │ │
|
||||
│ │ │ V = v_proj(x) → reshape │ │ │
|
||||
│ │ │ │ │ │ │
|
||||
│ │ │ ▼ │ │ │
|
||||
│ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │
|
||||
│ │ │ │ │ │ │
|
||||
│ │ │ ▼ │ │ │
|
||||
│ │ │ attn_output = attention(Q, K, V) │ │ │
|
||||
│ │ │ │ │ │ │
|
||||
│ │ │ ▼ │ │ │
|
||||
│ │ │ output = o_proj(attn_output) │ │ │
|
||||
│ │ └─────────────────────────────────────────────┘ │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ hidden_states = residual + attn_output │ │
|
||||
│ └───────────────────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌───────────────────────────────────────────────────┐ │
|
||||
│ │ MLP Block │ │
|
||||
│ │ │ │
|
||||
│ │ post_attention_layernorm (RMSNorm) │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ ┌─────────────────────────────────────────────┐ │ │
|
||||
│ │ │ Qwen3MLP │ │ │
|
||||
│ │ │ gate = gate_proj(x) │ │ │
|
||||
│ │ │ up = up_proj(x) │ │ │
|
||||
│ │ │ output = down_proj(silu(gate) * up) │ │ │
|
||||
│ │ └─────────────────────────────────────────────┘ │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ hidden_states = residual + mlp_output │ │
|
||||
│ └───────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ norm │ final RMSNorm
|
||||
└─────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ lm_head │ [hidden_size, vocab_size]
|
||||
└─────────────┘
|
||||
│
|
||||
▼
|
||||
logits [batch, seq_len, vocab_size]
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Qwen3RMSNorm(nn.Module):
|
||||
"""RMSNorm implementation."""
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
input_dtype = x.dtype
|
||||
x = x.float()
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.eps)
|
||||
return self.weight * x.to(input_dtype)
|
||||
|
||||
|
||||
class Qwen3RotaryEmbedding(nn.Module):
|
||||
"""Rotary Position Embedding (RoPE)."""
|
||||
|
||||
def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
|
||||
# Compute inverse frequencies
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x: Input tensor [batch, seq_len, num_heads, head_dim] or similar
|
||||
position_ids: Position indices [batch, seq_len]
|
||||
|
||||
Returns:
|
||||
cos, sin: [batch, seq_len, head_dim]
|
||||
"""
|
||||
# inv_freq: [dim/2]
|
||||
# position_ids: [batch, seq_len]
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float() # [1, dim/2, 1]
|
||||
position_ids_expanded = position_ids[:, None, :].float() # [batch, 1, seq_len]
|
||||
|
||||
# freqs: [batch, dim/2, seq_len]
|
||||
freqs = inv_freq_expanded @ position_ids_expanded
|
||||
# freqs: [batch, seq_len, dim/2]
|
||||
freqs = freqs.transpose(1, 2)
|
||||
|
||||
# Duplicate for full head_dim: [batch, seq_len, dim]
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
|
||||
cos = emb.cos().to(x.dtype)
|
||||
sin = emb.sin().to(x.dtype)
|
||||
|
||||
return cos, sin
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Rotate half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary position embeddings to Q and K.
|
||||
|
||||
Args:
|
||||
q: [batch, num_heads, seq_len, head_dim]
|
||||
k: [batch, num_kv_heads, seq_len, head_dim]
|
||||
cos: [batch, seq_len, head_dim]
|
||||
sin: [batch, seq_len, head_dim]
|
||||
|
||||
Returns:
|
||||
q_embed, k_embed with same shapes as inputs
|
||||
"""
|
||||
# Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim]
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
class Qwen3Attention(nn.Module):
|
||||
"""
|
||||
Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.
|
||||
|
||||
Data Flow:
|
||||
---------
|
||||
hidden_states [batch, seq_len, hidden_size]
|
||||
│
|
||||
├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim]
|
||||
├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim]
|
||||
└──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim]
|
||||
│
|
||||
▼
|
||||
apply_rotary_pos_emb(Q, K)
|
||||
│
|
||||
▼
|
||||
attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim]
|
||||
│
|
||||
▼
|
||||
reshape ──► o_proj ──► output [batch, seq_len, hidden_size]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 32768,
|
||||
rope_theta: float = 10000.0,
|
||||
attention_bias: bool = False,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
layer_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.num_kv_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_kv_groups = num_attention_heads // num_key_value_heads
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
# Scaling factor
|
||||
self.scaling = head_dim ** -0.5
|
||||
|
||||
# QKV projections
|
||||
self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
|
||||
|
||||
# QK normalization (Qwen3 specific)
|
||||
self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
|
||||
|
||||
# Rotary embeddings
|
||||
self.rotary_emb = Qwen3RotaryEmbedding(
|
||||
head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
output_qkv: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states: [batch, seq_len, hidden_size]
|
||||
position_ids: [batch, seq_len]
|
||||
attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask)
|
||||
past_key_value: (k_cache, v_cache) from previous steps
|
||||
use_cache: Whether to return updated cache
|
||||
output_qkv: Whether to output Q, K, V tensors for debugging
|
||||
|
||||
Returns:
|
||||
output: [batch, seq_len, hidden_size]
|
||||
past_key_value: Updated cache (if use_cache=True)
|
||||
qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
|
||||
"""
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
# === QKV Projections ===
|
||||
q = self.q_proj(hidden_states) # [batch, seq_len, num_heads * head_dim]
|
||||
k = self.k_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
|
||||
v = self.v_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
|
||||
|
||||
# Reshape to [batch, seq_len, num_heads, head_dim]
|
||||
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
|
||||
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
|
||||
|
||||
# === QK Normalization (Qwen3 specific) ===
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Transpose to [batch, num_heads, seq_len, head_dim]
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
# === Rotary Position Embeddings ===
|
||||
cos, sin = self.rotary_emb(v, position_ids)
|
||||
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
||||
|
||||
# === KV Cache Update ===
|
||||
if past_key_value is not None:
|
||||
k_cache, v_cache = past_key_value
|
||||
k = torch.cat([k_cache, k], dim=2)
|
||||
v = torch.cat([v_cache, v], dim=2)
|
||||
|
||||
new_past_key_value = (k, v) if use_cache else None
|
||||
|
||||
# === Grouped Query Attention (expand KV heads if needed) ===
|
||||
if self.num_kv_groups > 1:
|
||||
# Repeat KV for each query group
|
||||
k = k.repeat_interleave(self.num_kv_groups, dim=1)
|
||||
v = v.repeat_interleave(self.num_kv_groups, dim=1)
|
||||
|
||||
# === Attention Computation (using SDPA for memory efficiency) ===
|
||||
# Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend
|
||||
# is_causal only works when q_len == kv_len (prefill), not during decode
|
||||
q_len, kv_len = q.shape[2], k.shape[2]
|
||||
is_causal = (q_len == kv_len) and (q_len > 1)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=is_causal,
|
||||
scale=self.scaling,
|
||||
) # [batch, num_heads, seq_len, head_dim]
|
||||
|
||||
# === Output Projection ===
|
||||
# Transpose back and reshape
|
||||
attn_output = attn_output.transpose(1, 2).contiguous() # [batch, seq_len, num_heads, head_dim]
|
||||
attn_output = attn_output.view(batch_size, seq_len, -1) # [batch, seq_len, hidden_size]
|
||||
output = self.o_proj(attn_output)
|
||||
|
||||
# Optional QKV output for debugging
|
||||
qkv_dict = None
|
||||
if output_qkv:
|
||||
qkv_dict = {
|
||||
"q": q, # [batch, num_heads, seq_len, head_dim] (post-RoPE)
|
||||
"k": k, # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
|
||||
"v": v, # [batch, num_heads, kv_seq_len, head_dim] (expanded)
|
||||
}
|
||||
|
||||
return output, new_past_key_value, qkv_dict
|
||||
|
||||
|
||||
class Qwen3MLP(nn.Module):
|
||||
"""
|
||||
Qwen3 MLP with SwiGLU activation.
|
||||
|
||||
Data Flow:
|
||||
---------
|
||||
hidden_states [batch, seq_len, hidden_size]
|
||||
│
|
||||
├──► gate_proj ──► gate [batch, seq_len, intermediate_size]
|
||||
│
|
||||
└──► up_proj ──► up [batch, seq_len, intermediate_size]
|
||||
│
|
||||
▼
|
||||
silu(gate) * up
|
||||
│
|
||||
▼
|
||||
down_proj ──► output [batch, seq_len, hidden_size]
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
|
||||
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
|
||||
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
gate = self.gate_proj(x)
|
||||
up = self.up_proj(x)
|
||||
return self.down_proj(F.silu(gate) * up)
|
||||
|
||||
|
||||
class Qwen3DecoderLayer(nn.Module):
|
||||
"""Single Qwen3 Decoder Layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 32768,
|
||||
rope_theta: float = 10000.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
layer_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
# Pre-attention LayerNorm
|
||||
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
# Self-attention
|
||||
self.self_attn = Qwen3Attention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
attention_bias=attention_bias,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
# Post-attention LayerNorm
|
||||
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
# MLP
|
||||
self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
output_qkv: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states: [batch, seq_len, hidden_size]
|
||||
position_ids: [batch, seq_len]
|
||||
attention_mask: Causal attention mask
|
||||
past_key_value: KV cache for this layer
|
||||
use_cache: Whether to return updated cache
|
||||
output_qkv: Whether to output Q, K, V for debugging
|
||||
|
||||
Returns:
|
||||
hidden_states: [batch, seq_len, hidden_size]
|
||||
past_key_value: Updated cache
|
||||
qkv_dict: QKV tensors (if output_qkv=True)
|
||||
"""
|
||||
# === Self Attention Block ===
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
attn_output, new_past_key_value, qkv_dict = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
output_qkv=output_qkv,
|
||||
)
|
||||
|
||||
hidden_states = residual + attn_output
|
||||
|
||||
# === MLP Block ===
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states, new_past_key_value, qkv_dict
|
||||
|
||||
|
||||
class Qwen3Model(nn.Module):
|
||||
"""Qwen3 Transformer Model (without LM head)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_hidden_layers: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 32768,
|
||||
rope_theta: float = 10000.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
|
||||
# Token embeddings
|
||||
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
|
||||
|
||||
# Decoder layers
|
||||
self.layers = nn.ModuleList([
|
||||
Qwen3DecoderLayer(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
mlp_bias=mlp_bias,
|
||||
layer_idx=i,
|
||||
)
|
||||
for i in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
# Final LayerNorm
|
||||
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
use_cache: bool = False,
|
||||
output_qkv_layers: Optional[List[int]] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
|
||||
"""
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
position_ids: [batch, seq_len]
|
||||
attention_mask: [batch, seq_len] or pre-computed 4D mask
|
||||
past_key_values: List of (k, v) tuples for each layer
|
||||
use_cache: Whether to return new cache
|
||||
output_qkv_layers: List of layer indices to output QKV for
|
||||
|
||||
Returns:
|
||||
hidden_states: [batch, seq_len, hidden_size]
|
||||
new_past_key_values: Updated cache
|
||||
qkv_outputs: {layer_idx: qkv_dict}
|
||||
"""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
|
||||
# Embedding
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Position IDs
|
||||
if position_ids is None:
|
||||
past_len = past_key_values[0][0].shape[2] if past_key_values else 0
|
||||
position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
# Attention mask (create causal mask if not provided)
|
||||
if attention_mask is None or attention_mask.dim() == 2:
|
||||
kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
|
||||
causal_mask = torch.triu(
|
||||
torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
|
||||
diagonal=kv_seq_len - seq_len + 1,
|
||||
)
|
||||
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, kv_seq_len]
|
||||
|
||||
# Initialize cache list
|
||||
new_past_key_values = [] if use_cache else None
|
||||
qkv_outputs = {} if output_qkv_layers else None
|
||||
|
||||
# Decoder layers
|
||||
for i, layer in enumerate(self.layers):
|
||||
past_kv = past_key_values[i] if past_key_values else None
|
||||
output_qkv = output_qkv_layers is not None and i in output_qkv_layers
|
||||
|
||||
hidden_states, new_kv, qkv_dict = layer(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=past_kv,
|
||||
use_cache=use_cache,
|
||||
output_qkv=output_qkv,
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
new_past_key_values.append(new_kv)
|
||||
if qkv_dict is not None:
|
||||
qkv_outputs[i] = qkv_dict
|
||||
|
||||
# Final norm
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states, new_past_key_values, qkv_outputs
|
||||
|
||||
|
||||
class Qwen3ForCausalLM(nn.Module):
|
||||
"""Qwen3 Model with Language Modeling head."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_hidden_layers: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 32768,
|
||||
rope_theta: float = 10000.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
tie_word_embeddings: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
|
||||
# Transformer model
|
||||
self.model = Qwen3Model(
|
||||
vocab_size=vocab_size,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
mlp_bias=mlp_bias,
|
||||
)
|
||||
|
||||
# LM head
|
||||
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
use_cache: bool = False,
|
||||
output_qkv_layers: Optional[List[int]] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
|
||||
"""
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
... (same as Qwen3Model)
|
||||
|
||||
Returns:
|
||||
logits: [batch, seq_len, vocab_size]
|
||||
past_key_values: Updated KV cache
|
||||
qkv_outputs: QKV tensors for specified layers
|
||||
"""
|
||||
hidden_states, new_past_key_values, qkv_outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_qkv_layers=output_qkv_layers,
|
||||
)
|
||||
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
return logits, new_past_key_values, qkv_outputs
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
|
||||
"""
|
||||
Load weights from a pretrained Qwen3 model.
|
||||
|
||||
Args:
|
||||
model_path: Path to model directory containing config.json and model weights
|
||||
dtype: Data type for model weights
|
||||
|
||||
Returns:
|
||||
Initialized Qwen3ForCausalLM model
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from safetensors.torch import load_file
|
||||
|
||||
# Load config
|
||||
config_path = os.path.join(model_path, "config.json")
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Create model
|
||||
model = cls(
|
||||
vocab_size=config["vocab_size"],
|
||||
hidden_size=config["hidden_size"],
|
||||
intermediate_size=config["intermediate_size"],
|
||||
num_hidden_layers=config["num_hidden_layers"],
|
||||
num_attention_heads=config["num_attention_heads"],
|
||||
num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
|
||||
head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
|
||||
max_position_embeddings=config.get("max_position_embeddings", 32768),
|
||||
rope_theta=config.get("rope_theta", 10000.0),
|
||||
rms_norm_eps=config.get("rms_norm_eps", 1e-6),
|
||||
attention_bias=config.get("attention_bias", False),
|
||||
mlp_bias=config.get("mlp_bias", False),
|
||||
tie_word_embeddings=config.get("tie_word_embeddings", True),
|
||||
)
|
||||
|
||||
# Load weights
|
||||
weight_files = sorted([
|
||||
f for f in os.listdir(model_path)
|
||||
if f.endswith(".safetensors")
|
||||
])
|
||||
|
||||
state_dict = {}
|
||||
for wf in weight_files:
|
||||
state_dict.update(load_file(os.path.join(model_path, wf)))
|
||||
|
||||
# Load into model
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Tie lm_head weights to embed_tokens if configured
|
||||
if model.tie_word_embeddings:
|
||||
model.lm_head.weight = model.model.embed_tokens.weight
|
||||
|
||||
model = model.to(dtype)
|
||||
|
||||
return model
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
max_new_tokens: int = 32,
|
||||
temperature: float = 1.0,
|
||||
do_sample: bool = True,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Simple autoregressive generation."""
|
||||
device = input_ids.device
|
||||
batch_size, seq_len = input_ids.shape
|
||||
past_key_values = None
|
||||
generated = input_ids.clone()
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
if past_key_values is None:
|
||||
current_input = generated
|
||||
else:
|
||||
current_input = generated[:, -1:]
|
||||
|
||||
logits, past_key_values, _ = self(
|
||||
input_ids=current_input,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
next_token_logits = logits[:, -1, :]
|
||||
if temperature > 0 and do_sample:
|
||||
next_token_logits = next_token_logits / temperature
|
||||
probs = torch.softmax(next_token_logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
|
||||
|
||||
generated = torch.cat([generated, next_token], dim=1)
|
||||
|
||||
if eos_token_id is not None and (next_token == eos_token_id).all():
|
||||
break
|
||||
|
||||
return generated
|
||||
|
||||
|
||||
def print_computation_graph():
|
||||
"""Print the computation graph for reference."""
|
||||
print(__doc__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print_computation_graph()
|
||||
212
tests/test_align.py
Normal file
212
tests/test_align.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""
|
||||
Test alignment between nanovllm and custom torch Qwen3 implementation.
|
||||
Compares attention layer outputs to verify correctness.
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from modeling_qwen3 import Qwen3ForCausalLM
|
||||
from utils import generate_needle_prompt
|
||||
|
||||
# Config
|
||||
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
|
||||
INPUT_LEN = 512 # Use shorter length for alignment test
|
||||
DTYPE = torch.float16
|
||||
|
||||
# Storage for captured tensors
|
||||
nanovllm_outputs = {}
|
||||
torch_outputs = {}
|
||||
|
||||
|
||||
def make_nanovllm_hook(layer_id: int, storage: dict):
|
||||
"""Capture nanovllm self_attn outputs (after o_proj)."""
|
||||
def hook(module, inputs, output):
|
||||
# Qwen3Attention output is a tuple (attn_output, None)
|
||||
if isinstance(output, tuple):
|
||||
attn_output = output[0]
|
||||
else:
|
||||
attn_output = output
|
||||
# nanovllm shape: [num_tokens, hidden_size] -> add batch dim
|
||||
if attn_output.dim() == 2:
|
||||
attn_output = attn_output.unsqueeze(0)
|
||||
storage[layer_id] = attn_output.detach().clone()
|
||||
return hook
|
||||
|
||||
|
||||
def make_torch_hook(layer_id: int, storage: dict):
|
||||
"""Capture torch model self_attn outputs (after o_proj)."""
|
||||
def hook(module, inputs, output):
|
||||
# Qwen3Attention output is (attn_output, past_kv, qkv_dict)
|
||||
attn_output, _, _ = output
|
||||
storage[layer_id] = attn_output.detach().clone()
|
||||
return hook
|
||||
|
||||
|
||||
def compare_tensors(name: str, t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-2):
|
||||
"""Compare two tensors and print statistics."""
|
||||
# Handle shape differences
|
||||
if t1.shape != t2.shape:
|
||||
print(f"[{name}] Shape mismatch: {t1.shape} vs {t2.shape}")
|
||||
# Try to reshape for comparison if possible
|
||||
if t1.numel() == t2.numel():
|
||||
t2 = t2.view(t1.shape)
|
||||
else:
|
||||
return False
|
||||
|
||||
diff = (t1.float() - t2.float()).abs()
|
||||
max_diff = diff.max().item()
|
||||
mean_diff = diff.mean().item()
|
||||
|
||||
passed = max_diff < atol
|
||||
status = "PASS" if passed else "FAIL"
|
||||
|
||||
print(f"[{name}] {status}")
|
||||
print(f" Shape: {list(t1.shape)}")
|
||||
print(f" t1 mean: {t1.float().mean():.6f}, std: {t1.float().std():.6f}")
|
||||
print(f" t2 mean: {t2.float().mean():.6f}, std: {t2.float().std():.6f}")
|
||||
print(f" Max diff: {max_diff:.6f}, Mean diff: {mean_diff:.6f}")
|
||||
|
||||
return passed
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Load nanovllm model
|
||||
# ============================================================
|
||||
print("=" * 60)
|
||||
print("Loading nanovllm model...")
|
||||
print("=" * 60)
|
||||
|
||||
llm = LLM(
|
||||
MODEL_PATH,
|
||||
enforce_eager=True,
|
||||
max_model_len=4096,
|
||||
max_num_batched_tokens=4096,
|
||||
enable_cpu_offload=False, # Disable offload for alignment test
|
||||
dtype="float16",
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# Load torch model
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("Loading custom torch model...")
|
||||
print("=" * 60)
|
||||
|
||||
torch_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
|
||||
torch_model = torch_model.to("cuda")
|
||||
torch_model.eval()
|
||||
|
||||
# ============================================================
|
||||
# Generate test input
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("Generating test input...")
|
||||
print("=" * 60)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
||||
prompt, _ = generate_needle_prompt(
|
||||
tokenizer=tokenizer,
|
||||
target_length=INPUT_LEN,
|
||||
verbose=True,
|
||||
)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
|
||||
print(f"Input shape: {input_ids.shape}")
|
||||
|
||||
# ============================================================
|
||||
# Register hooks on both models
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("Registering hooks...")
|
||||
print("=" * 60)
|
||||
|
||||
# Hook on nanovllm (self_attn is Qwen3Attention, captures output after o_proj)
|
||||
nanovllm_hooks = []
|
||||
for layer_idx, layer in enumerate(llm.model_runner.model.model.layers):
|
||||
if layer_idx >= 2: # Only first 2 layers
|
||||
break
|
||||
nanovllm_hooks.append(
|
||||
layer.self_attn.register_forward_hook(
|
||||
make_nanovllm_hook(layer_idx, nanovllm_outputs)
|
||||
)
|
||||
)
|
||||
print(f" Registered nanovllm hook on layer {layer_idx} self_attn")
|
||||
|
||||
# Hook on torch model (self_attn is Qwen3Attention, captures output after o_proj)
|
||||
torch_hooks = []
|
||||
for layer_idx, layer in enumerate(torch_model.model.layers):
|
||||
if layer_idx >= 2: # Only first 2 layers
|
||||
break
|
||||
torch_hooks.append(
|
||||
layer.self_attn.register_forward_hook(
|
||||
make_torch_hook(layer_idx, torch_outputs)
|
||||
)
|
||||
)
|
||||
print(f" Registered torch hook on layer {layer_idx} self_attn")
|
||||
|
||||
# ============================================================
|
||||
# Run nanovllm inference
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("Running nanovllm inference...")
|
||||
print("=" * 60)
|
||||
|
||||
# Use prompt_token_ids to ensure same input
|
||||
prompt_token_ids = input_ids[0].tolist()
|
||||
nanovllm_result = llm.generate(
|
||||
[prompt_token_ids],
|
||||
SamplingParams(temperature=0.01, max_tokens=1), # Near-greedy for determinism
|
||||
use_tqdm=False,
|
||||
)
|
||||
|
||||
# ============================================================
|
||||
# Run torch inference
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("Running torch inference...")
|
||||
print("=" * 60)
|
||||
|
||||
with torch.no_grad():
|
||||
torch_logits, _, _ = torch_model(input_ids)
|
||||
|
||||
# ============================================================
|
||||
# Compare outputs
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
print("Comparing attention outputs...")
|
||||
print("=" * 60)
|
||||
|
||||
all_passed = True
|
||||
for layer_idx in sorted(nanovllm_outputs.keys()):
|
||||
if layer_idx not in torch_outputs:
|
||||
print(f"[Layer {layer_idx}] Missing torch output")
|
||||
all_passed = False
|
||||
continue
|
||||
|
||||
nano_out = nanovllm_outputs[layer_idx]
|
||||
torch_out = torch_outputs[layer_idx]
|
||||
|
||||
print(f"\n--- Layer {layer_idx} ---")
|
||||
passed = compare_tensors(f"Layer {layer_idx} attn_output", nano_out, torch_out, atol=0.1)
|
||||
all_passed = all_passed and passed
|
||||
|
||||
# ============================================================
|
||||
# Cleanup
|
||||
# ============================================================
|
||||
for hook in nanovllm_hooks:
|
||||
hook.remove()
|
||||
for hook in torch_hooks:
|
||||
hook.remove()
|
||||
|
||||
# ============================================================
|
||||
# Result
|
||||
# ============================================================
|
||||
print("\n" + "=" * 60)
|
||||
if all_passed:
|
||||
print("test_align: PASSED - nanovllm and torch outputs aligned!")
|
||||
else:
|
||||
print("test_align: FAILED - outputs differ!")
|
||||
print("=" * 60)
|
||||
@@ -12,151 +12,7 @@ os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
|
||||
|
||||
import argparse
|
||||
from nanovllm import LLM, SamplingParams
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Needle Test Generator
|
||||
# ============================================================
|
||||
|
||||
def generate_needle_prompt(
|
||||
tokenizer,
|
||||
target_length: int,
|
||||
needle_position: float = 0.5,
|
||||
needle_value: str = "7492",
|
||||
use_chat_template: bool = True,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Generate a needle-in-haystack prompt of approximately target_length tokens.
|
||||
|
||||
Args:
|
||||
tokenizer: HuggingFace tokenizer for length estimation
|
||||
target_length: Target total sequence length in tokens
|
||||
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
|
||||
needle_value: The secret value to hide in the haystack
|
||||
use_chat_template: Whether to use chat template for instruct models
|
||||
|
||||
Returns:
|
||||
(prompt, expected_answer): The full prompt and the expected needle value
|
||||
"""
|
||||
# Haystack filler paragraphs (various topics to create realistic context)
|
||||
haystack_paragraphs = [
|
||||
"The weather today is quite pleasant with clear skies and moderate temperatures. "
|
||||
"Many people are enjoying outdoor activities in the park. "
|
||||
"Birds are singing in the trees and children are playing on the swings. ",
|
||||
|
||||
"In the world of technology, new innovations continue to emerge every day. "
|
||||
"Researchers are working on advanced algorithms and computing systems. "
|
||||
"The future of artificial intelligence looks promising with many breakthroughs. ",
|
||||
|
||||
"The history of human civilization spans thousands of years. "
|
||||
"Ancient cultures developed writing, mathematics, and astronomy. "
|
||||
"Trade routes connected distant lands and facilitated cultural exchange. ",
|
||||
|
||||
"Modern cooking combines traditional techniques with new ingredients. "
|
||||
"Chefs around the world experiment with flavors and presentations. "
|
||||
"Food brings people together and creates memorable experiences. ",
|
||||
|
||||
"The ocean covers more than seventy percent of Earth's surface. "
|
||||
"Marine ecosystems support an incredible diversity of life forms. "
|
||||
"Scientists continue to discover new species in the deep sea. ",
|
||||
|
||||
"Music has been a part of human culture since prehistoric times. "
|
||||
"Different genres evolved across various regions and time periods. "
|
||||
"Today, people can access millions of songs through digital platforms. ",
|
||||
|
||||
"Space exploration has revealed many secrets about our universe. "
|
||||
"Telescopes can observe galaxies billions of light years away. "
|
||||
"Future missions aim to establish human presence on other planets. ",
|
||||
|
||||
"The study of languages reveals patterns in human cognition. "
|
||||
"Linguists analyze grammar, semantics, and phonetics across cultures. "
|
||||
"Language continues to evolve with new words and expressions. ",
|
||||
]
|
||||
|
||||
# The needle sentence
|
||||
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
|
||||
|
||||
# Question at the end
|
||||
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
||||
|
||||
# Estimate tokens for fixed parts
|
||||
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
|
||||
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
|
||||
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
|
||||
# Buffer for chat template, special tokens, etc.
|
||||
overhead_tokens = 100 if use_chat_template else 50
|
||||
|
||||
# Available tokens for haystack
|
||||
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
|
||||
if haystack_target_tokens < 100:
|
||||
raise ValueError(f"target_length {target_length} is too short for needle test")
|
||||
|
||||
# Build haystack by repeating paragraphs
|
||||
haystack_parts = []
|
||||
current_tokens = 0
|
||||
para_idx = 0
|
||||
|
||||
while current_tokens < haystack_target_tokens:
|
||||
para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
|
||||
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
|
||||
if current_tokens + para_tokens > haystack_target_tokens:
|
||||
break
|
||||
haystack_parts.append(para)
|
||||
current_tokens += para_tokens
|
||||
para_idx += 1
|
||||
|
||||
# Calculate needle insertion point
|
||||
needle_idx = int(len(haystack_parts) * needle_position)
|
||||
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
|
||||
|
||||
# Insert needle
|
||||
haystack_parts.insert(needle_idx, needle)
|
||||
|
||||
# Assemble prompt
|
||||
full_text = "".join(haystack_parts)
|
||||
|
||||
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
|
||||
# Use chat template for instruct models
|
||||
# For Qwen3, add /no_think to disable thinking mode
|
||||
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
|
||||
messages = [
|
||||
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
|
||||
]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
else:
|
||||
# Raw text format for base models
|
||||
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
||||
prompt = full_text + question
|
||||
|
||||
# Verify length
|
||||
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
|
||||
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
|
||||
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
|
||||
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
|
||||
|
||||
return prompt, needle_value
|
||||
|
||||
|
||||
def check_needle_answer(output_text: str, expected: str) -> bool:
|
||||
"""Check if the model output contains the expected needle value."""
|
||||
import re
|
||||
# Clean output - remove special tokens and whitespace
|
||||
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
|
||||
output_clean = ' '.join(output_clean.split()).lower()
|
||||
expected_clean = expected.strip().lower()
|
||||
|
||||
# Check if expected value appears in output
|
||||
# Also try to find it as a standalone number
|
||||
if expected_clean in output_clean:
|
||||
return True
|
||||
|
||||
# Try to extract numbers and check if expected is among them
|
||||
numbers = re.findall(r'\d+', output_clean)
|
||||
return expected_clean in numbers
|
||||
from utils import generate_needle_prompt, check_needle_answer
|
||||
|
||||
|
||||
# ============================================================
|
||||
|
||||
@@ -8,148 +8,9 @@ Uses standard HuggingFace inference (no custom KV cache, no offload).
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Needle Test Generator
|
||||
# ============================================================
|
||||
|
||||
def generate_needle_prompt(
|
||||
tokenizer,
|
||||
target_length: int,
|
||||
needle_position: float = 0.5,
|
||||
needle_value: str = "7492",
|
||||
use_chat_template: bool = True,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Generate a needle-in-haystack prompt of approximately target_length tokens.
|
||||
|
||||
Args:
|
||||
tokenizer: HuggingFace tokenizer for length estimation
|
||||
target_length: Target total sequence length in tokens
|
||||
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
|
||||
needle_value: The secret value to hide in the haystack
|
||||
use_chat_template: Whether to use chat template for instruct models
|
||||
|
||||
Returns:
|
||||
(prompt, expected_answer): The full prompt and the expected needle value
|
||||
"""
|
||||
# Haystack filler paragraphs (various topics to create realistic context)
|
||||
haystack_paragraphs = [
|
||||
"The weather today is quite pleasant with clear skies and moderate temperatures. "
|
||||
"Many people are enjoying outdoor activities in the park. "
|
||||
"Birds are singing in the trees and children are playing on the swings. ",
|
||||
|
||||
"In the world of technology, new innovations continue to emerge every day. "
|
||||
"Researchers are working on advanced algorithms and computing systems. "
|
||||
"The future of artificial intelligence looks promising with many breakthroughs. ",
|
||||
|
||||
"The history of human civilization spans thousands of years. "
|
||||
"Ancient cultures developed writing, mathematics, and astronomy. "
|
||||
"Trade routes connected distant lands and facilitated cultural exchange. ",
|
||||
|
||||
"Modern cooking combines traditional techniques with new ingredients. "
|
||||
"Chefs around the world experiment with flavors and presentations. "
|
||||
"Food brings people together and creates memorable experiences. ",
|
||||
|
||||
"The ocean covers more than seventy percent of Earth's surface. "
|
||||
"Marine ecosystems support an incredible diversity of life forms. "
|
||||
"Scientists continue to discover new species in the deep sea. ",
|
||||
|
||||
"Music has been a part of human culture since prehistoric times. "
|
||||
"Different genres evolved across various regions and time periods. "
|
||||
"Today, people can access millions of songs through digital platforms. ",
|
||||
|
||||
"Space exploration has revealed many secrets about our universe. "
|
||||
"Telescopes can observe galaxies billions of light years away. "
|
||||
"Future missions aim to establish human presence on other planets. ",
|
||||
|
||||
"The study of languages reveals patterns in human cognition. "
|
||||
"Linguists analyze grammar, semantics, and phonetics across cultures. "
|
||||
"Language continues to evolve with new words and expressions. ",
|
||||
]
|
||||
|
||||
# The needle sentence
|
||||
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
|
||||
|
||||
# Estimate tokens for fixed parts
|
||||
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
|
||||
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
|
||||
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
|
||||
# Buffer for chat template, special tokens, etc.
|
||||
overhead_tokens = 100 if use_chat_template else 50
|
||||
|
||||
# Available tokens for haystack
|
||||
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
|
||||
if haystack_target_tokens < 100:
|
||||
raise ValueError(f"target_length {target_length} is too short for needle test")
|
||||
|
||||
# Build haystack by repeating paragraphs
|
||||
haystack_parts = []
|
||||
current_tokens = 0
|
||||
para_idx = 0
|
||||
|
||||
while current_tokens < haystack_target_tokens:
|
||||
para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
|
||||
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
|
||||
if current_tokens + para_tokens > haystack_target_tokens:
|
||||
break
|
||||
haystack_parts.append(para)
|
||||
current_tokens += para_tokens
|
||||
para_idx += 1
|
||||
|
||||
# Calculate needle insertion point
|
||||
needle_idx = int(len(haystack_parts) * needle_position)
|
||||
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
|
||||
|
||||
# Insert needle
|
||||
haystack_parts.insert(needle_idx, needle)
|
||||
|
||||
# Assemble prompt
|
||||
full_text = "".join(haystack_parts)
|
||||
|
||||
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
|
||||
# Use chat template for instruct models
|
||||
# For Qwen3, add /no_think to disable thinking mode
|
||||
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
|
||||
messages = [
|
||||
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
|
||||
]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
else:
|
||||
# Raw text format for base models
|
||||
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
||||
prompt = full_text + question
|
||||
|
||||
# Verify length
|
||||
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
|
||||
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
|
||||
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
|
||||
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
|
||||
|
||||
return prompt, needle_value
|
||||
|
||||
|
||||
def check_needle_answer(output_text: str, expected: str) -> bool:
|
||||
"""Check if the model output contains the expected needle value."""
|
||||
import re
|
||||
# Clean output - remove special tokens and whitespace
|
||||
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
|
||||
output_clean = ' '.join(output_clean.split()).lower()
|
||||
expected_clean = expected.strip().lower()
|
||||
|
||||
# Check if expected value appears in output
|
||||
if expected_clean in output_clean:
|
||||
return True
|
||||
|
||||
# Try to extract numbers and check if expected is among them
|
||||
numbers = re.findall(r'\d+', output_clean)
|
||||
return expected_clean in numbers
|
||||
from transformers import AutoTokenizer
|
||||
from modeling_qwen3 import Qwen3ForCausalLM
|
||||
from utils import generate_needle_prompt, check_needle_answer
|
||||
|
||||
|
||||
# ============================================================
|
||||
@@ -207,22 +68,19 @@ def run_needle_test(
|
||||
# 3. Load model
|
||||
print("[3/4] Loading model...")
|
||||
torch_dtype = {
|
||||
"auto": "auto",
|
||||
"auto": torch.float16, # default to float16 for custom model
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
}.get(dtype, "auto")
|
||||
}.get(dtype, torch.float16)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch_dtype)
|
||||
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.eval()
|
||||
|
||||
# 4. Generate output
|
||||
print("[4/4] Running inference...")
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
|
||||
device = next(model.parameters()).device
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
||||
print(f" Input shape: {input_ids.shape}")
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
175
tests/utils.py
Normal file
175
tests/utils.py
Normal file
@@ -0,0 +1,175 @@
|
||||
"""
|
||||
Test utilities for nano-vllm.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Needle-in-Haystack Test Utilities
|
||||
# ============================================================
|
||||
|
||||
# Haystack filler paragraphs (various topics to create realistic context)
|
||||
HAYSTACK_PARAGRAPHS = [
|
||||
"The weather today is quite pleasant with clear skies and moderate temperatures. "
|
||||
"Many people are enjoying outdoor activities in the park. "
|
||||
"Birds are singing in the trees and children are playing on the swings. ",
|
||||
|
||||
"In the world of technology, new innovations continue to emerge every day. "
|
||||
"Researchers are working on advanced algorithms and computing systems. "
|
||||
"The future of artificial intelligence looks promising with many breakthroughs. ",
|
||||
|
||||
"The history of human civilization spans thousands of years. "
|
||||
"Ancient cultures developed writing, mathematics, and astronomy. "
|
||||
"Trade routes connected distant lands and facilitated cultural exchange. ",
|
||||
|
||||
"Modern cooking combines traditional techniques with new ingredients. "
|
||||
"Chefs around the world experiment with flavors and presentations. "
|
||||
"Food brings people together and creates memorable experiences. ",
|
||||
|
||||
"The ocean covers more than seventy percent of Earth's surface. "
|
||||
"Marine ecosystems support an incredible diversity of life forms. "
|
||||
"Scientists continue to discover new species in the deep sea. ",
|
||||
|
||||
"Music has been a part of human culture since prehistoric times. "
|
||||
"Different genres evolved across various regions and time periods. "
|
||||
"Today, people can access millions of songs through digital platforms. ",
|
||||
|
||||
"Space exploration has revealed many secrets about our universe. "
|
||||
"Telescopes can observe galaxies billions of light years away. "
|
||||
"Future missions aim to establish human presence on other planets. ",
|
||||
|
||||
"The study of languages reveals patterns in human cognition. "
|
||||
"Linguists analyze grammar, semantics, and phonetics across cultures. "
|
||||
"Language continues to evolve with new words and expressions. ",
|
||||
]
|
||||
|
||||
|
||||
def generate_needle_prompt(
|
||||
tokenizer,
|
||||
target_length: int,
|
||||
needle_position: float = 0.5,
|
||||
needle_value: str = "7492",
|
||||
use_chat_template: bool = True,
|
||||
verbose: bool = True,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Generate a needle-in-haystack prompt of approximately target_length tokens.
|
||||
|
||||
Args:
|
||||
tokenizer: HuggingFace tokenizer for length estimation
|
||||
target_length: Target total sequence length in tokens
|
||||
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
|
||||
needle_value: The secret value to hide in the haystack
|
||||
use_chat_template: Whether to use chat template for instruct models
|
||||
verbose: Whether to print generation info
|
||||
|
||||
Returns:
|
||||
(prompt, expected_answer): The full prompt and the expected needle value
|
||||
"""
|
||||
# The needle sentence
|
||||
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
|
||||
|
||||
# Question at the end
|
||||
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
||||
|
||||
# Estimate tokens for fixed parts
|
||||
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
|
||||
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
|
||||
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
|
||||
# Buffer for chat template, special tokens, etc.
|
||||
overhead_tokens = 100 if use_chat_template else 50
|
||||
|
||||
# Available tokens for haystack
|
||||
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
|
||||
if haystack_target_tokens < 100:
|
||||
raise ValueError(f"target_length {target_length} is too short for needle test")
|
||||
|
||||
# Build haystack by repeating paragraphs
|
||||
haystack_parts = []
|
||||
current_tokens = 0
|
||||
para_idx = 0
|
||||
|
||||
while current_tokens < haystack_target_tokens:
|
||||
para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
|
||||
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
|
||||
if current_tokens + para_tokens > haystack_target_tokens:
|
||||
break
|
||||
haystack_parts.append(para)
|
||||
current_tokens += para_tokens
|
||||
para_idx += 1
|
||||
|
||||
# Calculate needle insertion point
|
||||
needle_idx = int(len(haystack_parts) * needle_position)
|
||||
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
|
||||
|
||||
# Insert needle
|
||||
haystack_parts.insert(needle_idx, needle)
|
||||
|
||||
# Assemble prompt
|
||||
full_text = "".join(haystack_parts)
|
||||
|
||||
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
|
||||
# Use chat template for instruct models
|
||||
# For Qwen3, add /no_think to disable thinking mode
|
||||
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
|
||||
messages = [
|
||||
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
|
||||
]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
else:
|
||||
# Raw text format for base models
|
||||
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
||||
prompt = full_text + question
|
||||
|
||||
# Verify length
|
||||
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
|
||||
if verbose:
|
||||
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
|
||||
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
|
||||
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
|
||||
|
||||
return prompt, needle_value
|
||||
|
||||
|
||||
def check_needle_answer(output_text: str, expected: str) -> bool:
|
||||
"""Check if the model output contains the expected needle value."""
|
||||
# Clean output - remove special tokens and whitespace
|
||||
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
|
||||
output_clean = ' '.join(output_clean.split()).lower()
|
||||
expected_clean = expected.strip().lower()
|
||||
|
||||
# Check if expected value appears in output
|
||||
# Also try to find it as a standalone number
|
||||
if expected_clean in output_clean:
|
||||
return True
|
||||
|
||||
# Try to extract numbers and check if expected is among them
|
||||
numbers = re.findall(r'\d+', output_clean)
|
||||
return expected_clean in numbers
|
||||
|
||||
|
||||
def generate_random_token_ids(
|
||||
length: int,
|
||||
vocab_size: int = 10000,
|
||||
seed: int = 42,
|
||||
) -> list:
|
||||
"""
|
||||
Generate random token IDs for testing.
|
||||
|
||||
Args:
|
||||
length: Number of tokens to generate
|
||||
vocab_size: Maximum token ID (exclusive)
|
||||
seed: Random seed for reproducibility
|
||||
|
||||
Returns:
|
||||
List of random token IDs
|
||||
"""
|
||||
from random import randint, seed as set_seed
|
||||
set_seed(seed)
|
||||
return [randint(0, vocab_size - 1) for _ in range(length)]
|
||||
Reference in New Issue
Block a user