Compare commits

...

4 Commits

Author SHA1 Message Date
Zijie Tian
9b52d25866 [docs] Update CLAUDE.md. 2026-01-03 20:46:00 +08:00
Zijie Tian
8c3418725b [refactor] Refactor needle test. 2026-01-03 19:19:37 +08:00
Zijie Tian
b3685c9190 [test] Added test_align.py 2026-01-03 18:55:58 +08:00
Zijie Tian
6927a75ac3 [refactor] refactor needle.py. 2026-01-03 18:33:48 +08:00
6 changed files with 1228 additions and 296 deletions

View File

@@ -20,6 +20,80 @@ For sparse attention related content (block sparse attention, MInference, FlexPr
- **BlockManager** (`block_manager.py`): Paged attention with prefix caching (xxhash), default block size 4096
- **Attention** (`layers/attention.py`): FlashAttention with chunked methods for CPU offload
## PyTorch Hooks for Debugging
### Hook Positions in Qwen3
```
decoder_layer
├── input_layernorm (RMSNorm)
├── self_attn (Qwen3Attention) ← Hook here for attention I/O after o_proj
│ ├── q_proj → q_norm → RoPE
│ ├── k_proj → k_norm → RoPE
│ ├── v_proj
│ ├── attn (Attention) ← Hook here for Q/K/V tensors
│ │ └── FlashAttention / SDPA
│ └── o_proj
├── post_attention_layernorm (RMSNorm)
└── mlp (Qwen3MLP)
```
### Hook Types & Data Shapes
| Hook Position | Type | Captured Data |
|---------------|------|---------------|
| `self_attn` | post | `[batch, seq_len, hidden_size]` - after o_proj |
| `self_attn.attn` | pre | Q,K,V: `[seq_len, num_heads, head_dim]` - after RoPE |
| `self_attn.attn` | post | `[seq_len, num_heads, head_dim]` - before o_proj |
### Example: Capture Attention Outputs
```python
storage = {}
def make_hook(layer_id: int, storage: dict):
def hook(module, inputs, output):
if isinstance(output, tuple):
attn_output = output[0]
else:
attn_output = output
# nanovllm shape: [num_tokens, hidden_size] -> add batch dim
if attn_output.dim() == 2:
attn_output = attn_output.unsqueeze(0)
storage[layer_id] = attn_output.detach().clone()
return hook
# Register hooks
hooks = []
for layer_idx, layer in enumerate(model.model.layers):
hooks.append(layer.self_attn.register_forward_hook(make_hook(layer_idx, storage)))
# Run inference...
# Cleanup
for hook in hooks:
hook.remove()
```
### Alignment Testing
Use `tests/test_align.py` to compare nanovllm with reference torch implementation:
```bash
python tests/test_align.py
```
Key files:
- `tests/modeling_qwen3.py`: Reference Qwen3 implementation (torch + transformers only)
- `tests/test_align.py`: Compares attention outputs between nanovllm and reference
- `tests/test_needle_ref.py`: Reference needle test using custom Qwen3
### Common Pitfalls
1. **Shape mismatch**: nanovllm uses `[num_tokens, ...]` while torch uses `[batch, seq_len, ...]`
2. **Hook position**: `self_attn` captures after o_proj, `self_attn.attn` captures before o_proj
3. **Output format**: nanovllm returns tuple `(attn_output, None)`, handle with `output[0]`
## CPU Offload System
### Ring Buffer Design

757
tests/modeling_qwen3.py Normal file
View File

@@ -0,0 +1,757 @@
"""
Custom Qwen3 implementation using only torch and transformers.
This file provides a clean reference implementation for understanding the model computation graph.
Computation Graph:
==================
Input: token_ids [batch, seq_len]
┌─────────────┐
│ Embedding │ embed_tokens: [vocab_size, hidden_size]
└─────────────┘
hidden_states [batch, seq_len, hidden_size]
┌─────────────────────────────────────────────────────────┐
│ Decoder Layer (x N) │
│ ┌───────────────────────────────────────────────────┐ │
│ │ Self Attention Block │ │
│ │ │ │
│ │ input_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3Attention │ │ │
│ │ │ Q = q_proj(x) → q_norm → reshape │ │ │
│ │ │ K = k_proj(x) → k_norm → reshape │ │ │
│ │ │ V = v_proj(x) → reshape │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ attn_output = attention(Q, K, V) │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ output = o_proj(attn_output) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + attn_output │ │
│ └───────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ MLP Block │ │
│ │ │ │
│ │ post_attention_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3MLP │ │ │
│ │ │ gate = gate_proj(x) │ │ │
│ │ │ up = up_proj(x) │ │ │
│ │ │ output = down_proj(silu(gate) * up) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + mlp_output │ │
│ └───────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
┌─────────────┐
│ norm │ final RMSNorm
└─────────────┘
┌─────────────┐
│ lm_head │ [hidden_size, vocab_size]
└─────────────┘
logits [batch, seq_len, vocab_size]
"""
import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
class Qwen3RMSNorm(nn.Module):
"""RMSNorm implementation."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
input_dtype = x.dtype
x = x.float()
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
return self.weight * x.to(input_dtype)
class Qwen3RotaryEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE)."""
def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Compute inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Input tensor [batch, seq_len, num_heads, head_dim] or similar
position_ids: Position indices [batch, seq_len]
Returns:
cos, sin: [batch, seq_len, head_dim]
"""
# inv_freq: [dim/2]
# position_ids: [batch, seq_len]
inv_freq_expanded = self.inv_freq[None, :, None].float() # [1, dim/2, 1]
position_ids_expanded = position_ids[:, None, :].float() # [batch, 1, seq_len]
# freqs: [batch, dim/2, seq_len]
freqs = inv_freq_expanded @ position_ids_expanded
# freqs: [batch, seq_len, dim/2]
freqs = freqs.transpose(1, 2)
# Duplicate for full head_dim: [batch, seq_len, dim]
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(x.dtype)
sin = emb.sin().to(x.dtype)
return cos, sin
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary position embeddings to Q and K.
Args:
q: [batch, num_heads, seq_len, head_dim]
k: [batch, num_kv_heads, seq_len, head_dim]
cos: [batch, seq_len, head_dim]
sin: [batch, seq_len, head_dim]
Returns:
q_embed, k_embed with same shapes as inputs
"""
# Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim]
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class Qwen3Attention(nn.Module):
"""
Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim]
├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim]
└──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim]
apply_rotary_pos_emb(Q, K)
attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim]
reshape ──► o_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
attention_bias: bool = False,
rms_norm_eps: float = 1e-6,
layer_idx: int = 0,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.num_kv_heads = num_key_value_heads
self.head_dim = head_dim
self.num_kv_groups = num_attention_heads // num_key_value_heads
self.layer_idx = layer_idx
# Scaling factor
self.scaling = head_dim ** -0.5
# QKV projections
self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
# QK normalization (Qwen3 specific)
self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
# Rotary embeddings
self.rotary_emb = Qwen3RotaryEmbedding(
head_dim,
max_position_embeddings=max_position_embeddings,
base=rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask)
past_key_value: (k_cache, v_cache) from previous steps
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V tensors for debugging
Returns:
output: [batch, seq_len, hidden_size]
past_key_value: Updated cache (if use_cache=True)
qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
"""
batch_size, seq_len, _ = hidden_states.shape
# === QKV Projections ===
q = self.q_proj(hidden_states) # [batch, seq_len, num_heads * head_dim]
k = self.k_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
v = self.v_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
# Reshape to [batch, seq_len, num_heads, head_dim]
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# === QK Normalization (Qwen3 specific) ===
q = self.q_norm(q)
k = self.k_norm(k)
# Transpose to [batch, num_heads, seq_len, head_dim]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# === Rotary Position Embeddings ===
cos, sin = self.rotary_emb(v, position_ids)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
# === KV Cache Update ===
if past_key_value is not None:
k_cache, v_cache = past_key_value
k = torch.cat([k_cache, k], dim=2)
v = torch.cat([v_cache, v], dim=2)
new_past_key_value = (k, v) if use_cache else None
# === Grouped Query Attention (expand KV heads if needed) ===
if self.num_kv_groups > 1:
# Repeat KV for each query group
k = k.repeat_interleave(self.num_kv_groups, dim=1)
v = v.repeat_interleave(self.num_kv_groups, dim=1)
# === Attention Computation (using SDPA for memory efficiency) ===
# Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend
# is_causal only works when q_len == kv_len (prefill), not during decode
q_len, kv_len = q.shape[2], k.shape[2]
is_causal = (q_len == kv_len) and (q_len > 1)
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=0.0,
is_causal=is_causal,
scale=self.scaling,
) # [batch, num_heads, seq_len, head_dim]
# === Output Projection ===
# Transpose back and reshape
attn_output = attn_output.transpose(1, 2).contiguous() # [batch, seq_len, num_heads, head_dim]
attn_output = attn_output.view(batch_size, seq_len, -1) # [batch, seq_len, hidden_size]
output = self.o_proj(attn_output)
# Optional QKV output for debugging
qkv_dict = None
if output_qkv:
qkv_dict = {
"q": q, # [batch, num_heads, seq_len, head_dim] (post-RoPE)
"k": k, # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
"v": v, # [batch, num_heads, kv_seq_len, head_dim] (expanded)
}
return output, new_past_key_value, qkv_dict
class Qwen3MLP(nn.Module):
"""
Qwen3 MLP with SwiGLU activation.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► gate_proj ──► gate [batch, seq_len, intermediate_size]
└──► up_proj ──► up [batch, seq_len, intermediate_size]
silu(gate) * up
down_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.gate_proj(x)
up = self.up_proj(x)
return self.down_proj(F.silu(gate) * up)
class Qwen3DecoderLayer(nn.Module):
"""Single Qwen3 Decoder Layer."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
layer_idx: int = 0,
):
super().__init__()
self.layer_idx = layer_idx
# Pre-attention LayerNorm
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# Self-attention
self.self_attn = Qwen3Attention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
attention_bias=attention_bias,
rms_norm_eps=rms_norm_eps,
layer_idx=layer_idx,
)
# Post-attention LayerNorm
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# MLP
self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: Causal attention mask
past_key_value: KV cache for this layer
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V for debugging
Returns:
hidden_states: [batch, seq_len, hidden_size]
past_key_value: Updated cache
qkv_dict: QKV tensors (if output_qkv=True)
"""
# === Self Attention Block ===
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_output, new_past_key_value, qkv_dict = self.self_attn(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_qkv=output_qkv,
)
hidden_states = residual + attn_output
# === MLP Block ===
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, new_past_key_value, qkv_dict
class Qwen3Model(nn.Module):
"""Qwen3 Transformer Model (without LM head)."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
):
super().__init__()
self.vocab_size = vocab_size
self.num_hidden_layers = num_hidden_layers
# Token embeddings
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
# Decoder layers
self.layers = nn.ModuleList([
Qwen3DecoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
layer_idx=i,
)
for i in range(num_hidden_layers)
])
# Final LayerNorm
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
position_ids: [batch, seq_len]
attention_mask: [batch, seq_len] or pre-computed 4D mask
past_key_values: List of (k, v) tuples for each layer
use_cache: Whether to return new cache
output_qkv_layers: List of layer indices to output QKV for
Returns:
hidden_states: [batch, seq_len, hidden_size]
new_past_key_values: Updated cache
qkv_outputs: {layer_idx: qkv_dict}
"""
batch_size, seq_len = input_ids.shape
# Embedding
hidden_states = self.embed_tokens(input_ids)
# Position IDs
if position_ids is None:
past_len = past_key_values[0][0].shape[2] if past_key_values else 0
position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Attention mask (create causal mask if not provided)
if attention_mask is None or attention_mask.dim() == 2:
kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
causal_mask = torch.triu(
torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
diagonal=kv_seq_len - seq_len + 1,
)
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, kv_seq_len]
# Initialize cache list
new_past_key_values = [] if use_cache else None
qkv_outputs = {} if output_qkv_layers else None
# Decoder layers
for i, layer in enumerate(self.layers):
past_kv = past_key_values[i] if past_key_values else None
output_qkv = output_qkv_layers is not None and i in output_qkv_layers
hidden_states, new_kv, qkv_dict = layer(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_kv,
use_cache=use_cache,
output_qkv=output_qkv,
)
if use_cache:
new_past_key_values.append(new_kv)
if qkv_dict is not None:
qkv_outputs[i] = qkv_dict
# Final norm
hidden_states = self.norm(hidden_states)
return hidden_states, new_past_key_values, qkv_outputs
class Qwen3ForCausalLM(nn.Module):
"""Qwen3 Model with Language Modeling head."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
tie_word_embeddings: bool = True,
):
super().__init__()
self.vocab_size = vocab_size
self.tie_word_embeddings = tie_word_embeddings
# Transformer model
self.model = Qwen3Model(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
)
# LM head
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
... (same as Qwen3Model)
Returns:
logits: [batch, seq_len, vocab_size]
past_key_values: Updated KV cache
qkv_outputs: QKV tensors for specified layers
"""
hidden_states, new_past_key_values, qkv_outputs = self.model(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_qkv_layers=output_qkv_layers,
)
logits = self.lm_head(hidden_states)
return logits, new_past_key_values, qkv_outputs
@classmethod
def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
"""
Load weights from a pretrained Qwen3 model.
Args:
model_path: Path to model directory containing config.json and model weights
dtype: Data type for model weights
Returns:
Initialized Qwen3ForCausalLM model
"""
import json
import os
from safetensors.torch import load_file
# Load config
config_path = os.path.join(model_path, "config.json")
with open(config_path) as f:
config = json.load(f)
# Create model
model = cls(
vocab_size=config["vocab_size"],
hidden_size=config["hidden_size"],
intermediate_size=config["intermediate_size"],
num_hidden_layers=config["num_hidden_layers"],
num_attention_heads=config["num_attention_heads"],
num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
max_position_embeddings=config.get("max_position_embeddings", 32768),
rope_theta=config.get("rope_theta", 10000.0),
rms_norm_eps=config.get("rms_norm_eps", 1e-6),
attention_bias=config.get("attention_bias", False),
mlp_bias=config.get("mlp_bias", False),
tie_word_embeddings=config.get("tie_word_embeddings", True),
)
# Load weights
weight_files = sorted([
f for f in os.listdir(model_path)
if f.endswith(".safetensors")
])
state_dict = {}
for wf in weight_files:
state_dict.update(load_file(os.path.join(model_path, wf)))
# Load into model
model.load_state_dict(state_dict, strict=False)
# Tie lm_head weights to embed_tokens if configured
if model.tie_word_embeddings:
model.lm_head.weight = model.model.embed_tokens.weight
model = model.to(dtype)
return model
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 32,
temperature: float = 1.0,
do_sample: bool = True,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> torch.Tensor:
"""Simple autoregressive generation."""
device = input_ids.device
batch_size, seq_len = input_ids.shape
past_key_values = None
generated = input_ids.clone()
for _ in range(max_new_tokens):
if past_key_values is None:
current_input = generated
else:
current_input = generated[:, -1:]
logits, past_key_values, _ = self(
input_ids=current_input,
past_key_values=past_key_values,
use_cache=True,
)
next_token_logits = logits[:, -1, :]
if temperature > 0 and do_sample:
next_token_logits = next_token_logits / temperature
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=1)
if eos_token_id is not None and (next_token == eos_token_id).all():
break
return generated
def print_computation_graph():
"""Print the computation graph for reference."""
print(__doc__)
if __name__ == "__main__":
print_computation_graph()

212
tests/test_align.py Normal file
View File

@@ -0,0 +1,212 @@
"""
Test alignment between nanovllm and custom torch Qwen3 implementation.
Compares attention layer outputs to verify correctness.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import torch
from transformers import AutoTokenizer
from nanovllm import LLM, SamplingParams
from modeling_qwen3 import Qwen3ForCausalLM
from utils import generate_needle_prompt
# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
INPUT_LEN = 512 # Use shorter length for alignment test
DTYPE = torch.float16
# Storage for captured tensors
nanovllm_outputs = {}
torch_outputs = {}
def make_nanovllm_hook(layer_id: int, storage: dict):
"""Capture nanovllm self_attn outputs (after o_proj)."""
def hook(module, inputs, output):
# Qwen3Attention output is a tuple (attn_output, None)
if isinstance(output, tuple):
attn_output = output[0]
else:
attn_output = output
# nanovllm shape: [num_tokens, hidden_size] -> add batch dim
if attn_output.dim() == 2:
attn_output = attn_output.unsqueeze(0)
storage[layer_id] = attn_output.detach().clone()
return hook
def make_torch_hook(layer_id: int, storage: dict):
"""Capture torch model self_attn outputs (after o_proj)."""
def hook(module, inputs, output):
# Qwen3Attention output is (attn_output, past_kv, qkv_dict)
attn_output, _, _ = output
storage[layer_id] = attn_output.detach().clone()
return hook
def compare_tensors(name: str, t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-2):
"""Compare two tensors and print statistics."""
# Handle shape differences
if t1.shape != t2.shape:
print(f"[{name}] Shape mismatch: {t1.shape} vs {t2.shape}")
# Try to reshape for comparison if possible
if t1.numel() == t2.numel():
t2 = t2.view(t1.shape)
else:
return False
diff = (t1.float() - t2.float()).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
passed = max_diff < atol
status = "PASS" if passed else "FAIL"
print(f"[{name}] {status}")
print(f" Shape: {list(t1.shape)}")
print(f" t1 mean: {t1.float().mean():.6f}, std: {t1.float().std():.6f}")
print(f" t2 mean: {t2.float().mean():.6f}, std: {t2.float().std():.6f}")
print(f" Max diff: {max_diff:.6f}, Mean diff: {mean_diff:.6f}")
return passed
# ============================================================
# Load nanovllm model
# ============================================================
print("=" * 60)
print("Loading nanovllm model...")
print("=" * 60)
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=4096,
max_num_batched_tokens=4096,
enable_cpu_offload=False, # Disable offload for alignment test
dtype="float16",
)
# ============================================================
# Load torch model
# ============================================================
print("\n" + "=" * 60)
print("Loading custom torch model...")
print("=" * 60)
torch_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
torch_model = torch_model.to("cuda")
torch_model.eval()
# ============================================================
# Generate test input
# ============================================================
print("\n" + "=" * 60)
print("Generating test input...")
print("=" * 60)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
prompt, _ = generate_needle_prompt(
tokenizer=tokenizer,
target_length=INPUT_LEN,
verbose=True,
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
print(f"Input shape: {input_ids.shape}")
# ============================================================
# Register hooks on both models
# ============================================================
print("\n" + "=" * 60)
print("Registering hooks...")
print("=" * 60)
# Hook on nanovllm (self_attn is Qwen3Attention, captures output after o_proj)
nanovllm_hooks = []
for layer_idx, layer in enumerate(llm.model_runner.model.model.layers):
if layer_idx >= 2: # Only first 2 layers
break
nanovllm_hooks.append(
layer.self_attn.register_forward_hook(
make_nanovllm_hook(layer_idx, nanovllm_outputs)
)
)
print(f" Registered nanovllm hook on layer {layer_idx} self_attn")
# Hook on torch model (self_attn is Qwen3Attention, captures output after o_proj)
torch_hooks = []
for layer_idx, layer in enumerate(torch_model.model.layers):
if layer_idx >= 2: # Only first 2 layers
break
torch_hooks.append(
layer.self_attn.register_forward_hook(
make_torch_hook(layer_idx, torch_outputs)
)
)
print(f" Registered torch hook on layer {layer_idx} self_attn")
# ============================================================
# Run nanovllm inference
# ============================================================
print("\n" + "=" * 60)
print("Running nanovllm inference...")
print("=" * 60)
# Use prompt_token_ids to ensure same input
prompt_token_ids = input_ids[0].tolist()
nanovllm_result = llm.generate(
[prompt_token_ids],
SamplingParams(temperature=0.01, max_tokens=1), # Near-greedy for determinism
use_tqdm=False,
)
# ============================================================
# Run torch inference
# ============================================================
print("\n" + "=" * 60)
print("Running torch inference...")
print("=" * 60)
with torch.no_grad():
torch_logits, _, _ = torch_model(input_ids)
# ============================================================
# Compare outputs
# ============================================================
print("\n" + "=" * 60)
print("Comparing attention outputs...")
print("=" * 60)
all_passed = True
for layer_idx in sorted(nanovllm_outputs.keys()):
if layer_idx not in torch_outputs:
print(f"[Layer {layer_idx}] Missing torch output")
all_passed = False
continue
nano_out = nanovllm_outputs[layer_idx]
torch_out = torch_outputs[layer_idx]
print(f"\n--- Layer {layer_idx} ---")
passed = compare_tensors(f"Layer {layer_idx} attn_output", nano_out, torch_out, atol=0.1)
all_passed = all_passed and passed
# ============================================================
# Cleanup
# ============================================================
for hook in nanovllm_hooks:
hook.remove()
for hook in torch_hooks:
hook.remove()
# ============================================================
# Result
# ============================================================
print("\n" + "=" * 60)
if all_passed:
print("test_align: PASSED - nanovllm and torch outputs aligned!")
else:
print("test_align: FAILED - outputs differ!")
print("=" * 60)

View File

@@ -12,151 +12,7 @@ os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
import argparse
from nanovllm import LLM, SamplingParams
# ============================================================
# Needle Test Generator
# ============================================================
def generate_needle_prompt(
tokenizer,
target_length: int,
needle_position: float = 0.5,
needle_value: str = "7492",
use_chat_template: bool = True,
) -> tuple[str, str]:
"""
Generate a needle-in-haystack prompt of approximately target_length tokens.
Args:
tokenizer: HuggingFace tokenizer for length estimation
target_length: Target total sequence length in tokens
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
needle_value: The secret value to hide in the haystack
use_chat_template: Whether to use chat template for instruct models
Returns:
(prompt, expected_answer): The full prompt and the expected needle value
"""
# Haystack filler paragraphs (various topics to create realistic context)
haystack_paragraphs = [
"The weather today is quite pleasant with clear skies and moderate temperatures. "
"Many people are enjoying outdoor activities in the park. "
"Birds are singing in the trees and children are playing on the swings. ",
"In the world of technology, new innovations continue to emerge every day. "
"Researchers are working on advanced algorithms and computing systems. "
"The future of artificial intelligence looks promising with many breakthroughs. ",
"The history of human civilization spans thousands of years. "
"Ancient cultures developed writing, mathematics, and astronomy. "
"Trade routes connected distant lands and facilitated cultural exchange. ",
"Modern cooking combines traditional techniques with new ingredients. "
"Chefs around the world experiment with flavors and presentations. "
"Food brings people together and creates memorable experiences. ",
"The ocean covers more than seventy percent of Earth's surface. "
"Marine ecosystems support an incredible diversity of life forms. "
"Scientists continue to discover new species in the deep sea. ",
"Music has been a part of human culture since prehistoric times. "
"Different genres evolved across various regions and time periods. "
"Today, people can access millions of songs through digital platforms. ",
"Space exploration has revealed many secrets about our universe. "
"Telescopes can observe galaxies billions of light years away. "
"Future missions aim to establish human presence on other planets. ",
"The study of languages reveals patterns in human cognition. "
"Linguists analyze grammar, semantics, and phonetics across cultures. "
"Language continues to evolve with new words and expressions. ",
]
# The needle sentence
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
# Question at the end
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
# Estimate tokens for fixed parts
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
# Buffer for chat template, special tokens, etc.
overhead_tokens = 100 if use_chat_template else 50
# Available tokens for haystack
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
if haystack_target_tokens < 100:
raise ValueError(f"target_length {target_length} is too short for needle test")
# Build haystack by repeating paragraphs
haystack_parts = []
current_tokens = 0
para_idx = 0
while current_tokens < haystack_target_tokens:
para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
if current_tokens + para_tokens > haystack_target_tokens:
break
haystack_parts.append(para)
current_tokens += para_tokens
para_idx += 1
# Calculate needle insertion point
needle_idx = int(len(haystack_parts) * needle_position)
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
# Insert needle
haystack_parts.insert(needle_idx, needle)
# Assemble prompt
full_text = "".join(haystack_parts)
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
# Use chat template for instruct models
# For Qwen3, add /no_think to disable thinking mode
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
messages = [
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
else:
# Raw text format for base models
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
prompt = full_text + question
# Verify length
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
return prompt, needle_value
def check_needle_answer(output_text: str, expected: str) -> bool:
"""Check if the model output contains the expected needle value."""
import re
# Clean output - remove special tokens and whitespace
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
output_clean = ' '.join(output_clean.split()).lower()
expected_clean = expected.strip().lower()
# Check if expected value appears in output
# Also try to find it as a standalone number
if expected_clean in output_clean:
return True
# Try to extract numbers and check if expected is among them
numbers = re.findall(r'\d+', output_clean)
return expected_clean in numbers
from utils import generate_needle_prompt, check_needle_answer
# ============================================================

View File

@@ -8,148 +8,9 @@ Uses standard HuggingFace inference (no custom KV cache, no offload).
import os
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ============================================================
# Needle Test Generator
# ============================================================
def generate_needle_prompt(
tokenizer,
target_length: int,
needle_position: float = 0.5,
needle_value: str = "7492",
use_chat_template: bool = True,
) -> tuple[str, str]:
"""
Generate a needle-in-haystack prompt of approximately target_length tokens.
Args:
tokenizer: HuggingFace tokenizer for length estimation
target_length: Target total sequence length in tokens
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
needle_value: The secret value to hide in the haystack
use_chat_template: Whether to use chat template for instruct models
Returns:
(prompt, expected_answer): The full prompt and the expected needle value
"""
# Haystack filler paragraphs (various topics to create realistic context)
haystack_paragraphs = [
"The weather today is quite pleasant with clear skies and moderate temperatures. "
"Many people are enjoying outdoor activities in the park. "
"Birds are singing in the trees and children are playing on the swings. ",
"In the world of technology, new innovations continue to emerge every day. "
"Researchers are working on advanced algorithms and computing systems. "
"The future of artificial intelligence looks promising with many breakthroughs. ",
"The history of human civilization spans thousands of years. "
"Ancient cultures developed writing, mathematics, and astronomy. "
"Trade routes connected distant lands and facilitated cultural exchange. ",
"Modern cooking combines traditional techniques with new ingredients. "
"Chefs around the world experiment with flavors and presentations. "
"Food brings people together and creates memorable experiences. ",
"The ocean covers more than seventy percent of Earth's surface. "
"Marine ecosystems support an incredible diversity of life forms. "
"Scientists continue to discover new species in the deep sea. ",
"Music has been a part of human culture since prehistoric times. "
"Different genres evolved across various regions and time periods. "
"Today, people can access millions of songs through digital platforms. ",
"Space exploration has revealed many secrets about our universe. "
"Telescopes can observe galaxies billions of light years away. "
"Future missions aim to establish human presence on other planets. ",
"The study of languages reveals patterns in human cognition. "
"Linguists analyze grammar, semantics, and phonetics across cultures. "
"Language continues to evolve with new words and expressions. ",
]
# The needle sentence
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
# Estimate tokens for fixed parts
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
# Buffer for chat template, special tokens, etc.
overhead_tokens = 100 if use_chat_template else 50
# Available tokens for haystack
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
if haystack_target_tokens < 100:
raise ValueError(f"target_length {target_length} is too short for needle test")
# Build haystack by repeating paragraphs
haystack_parts = []
current_tokens = 0
para_idx = 0
while current_tokens < haystack_target_tokens:
para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
if current_tokens + para_tokens > haystack_target_tokens:
break
haystack_parts.append(para)
current_tokens += para_tokens
para_idx += 1
# Calculate needle insertion point
needle_idx = int(len(haystack_parts) * needle_position)
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
# Insert needle
haystack_parts.insert(needle_idx, needle)
# Assemble prompt
full_text = "".join(haystack_parts)
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
# Use chat template for instruct models
# For Qwen3, add /no_think to disable thinking mode
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
messages = [
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
else:
# Raw text format for base models
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
prompt = full_text + question
# Verify length
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
return prompt, needle_value
def check_needle_answer(output_text: str, expected: str) -> bool:
"""Check if the model output contains the expected needle value."""
import re
# Clean output - remove special tokens and whitespace
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
output_clean = ' '.join(output_clean.split()).lower()
expected_clean = expected.strip().lower()
# Check if expected value appears in output
if expected_clean in output_clean:
return True
# Try to extract numbers and check if expected is among them
numbers = re.findall(r'\d+', output_clean)
return expected_clean in numbers
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
@@ -207,22 +68,19 @@ def run_needle_test(
# 3. Load model
print("[3/4] Loading model...")
torch_dtype = {
"auto": "auto",
"auto": torch.float16, # default to float16 for custom model
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}.get(dtype, "auto")
}.get(dtype, torch.float16)
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
device_map="auto",
trust_remote_code=True,
)
model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch_dtype)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
# 4. Generate output
print("[4/4] Running inference...")
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
device = next(model.parameters()).device
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
print(f" Input shape: {input_ids.shape}")
with torch.no_grad():

175
tests/utils.py Normal file
View File

@@ -0,0 +1,175 @@
"""
Test utilities for nano-vllm.
"""
import re
from typing import Tuple
# ============================================================
# Needle-in-Haystack Test Utilities
# ============================================================
# Haystack filler paragraphs (various topics to create realistic context)
HAYSTACK_PARAGRAPHS = [
"The weather today is quite pleasant with clear skies and moderate temperatures. "
"Many people are enjoying outdoor activities in the park. "
"Birds are singing in the trees and children are playing on the swings. ",
"In the world of technology, new innovations continue to emerge every day. "
"Researchers are working on advanced algorithms and computing systems. "
"The future of artificial intelligence looks promising with many breakthroughs. ",
"The history of human civilization spans thousands of years. "
"Ancient cultures developed writing, mathematics, and astronomy. "
"Trade routes connected distant lands and facilitated cultural exchange. ",
"Modern cooking combines traditional techniques with new ingredients. "
"Chefs around the world experiment with flavors and presentations. "
"Food brings people together and creates memorable experiences. ",
"The ocean covers more than seventy percent of Earth's surface. "
"Marine ecosystems support an incredible diversity of life forms. "
"Scientists continue to discover new species in the deep sea. ",
"Music has been a part of human culture since prehistoric times. "
"Different genres evolved across various regions and time periods. "
"Today, people can access millions of songs through digital platforms. ",
"Space exploration has revealed many secrets about our universe. "
"Telescopes can observe galaxies billions of light years away. "
"Future missions aim to establish human presence on other planets. ",
"The study of languages reveals patterns in human cognition. "
"Linguists analyze grammar, semantics, and phonetics across cultures. "
"Language continues to evolve with new words and expressions. ",
]
def generate_needle_prompt(
tokenizer,
target_length: int,
needle_position: float = 0.5,
needle_value: str = "7492",
use_chat_template: bool = True,
verbose: bool = True,
) -> Tuple[str, str]:
"""
Generate a needle-in-haystack prompt of approximately target_length tokens.
Args:
tokenizer: HuggingFace tokenizer for length estimation
target_length: Target total sequence length in tokens
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
needle_value: The secret value to hide in the haystack
use_chat_template: Whether to use chat template for instruct models
verbose: Whether to print generation info
Returns:
(prompt, expected_answer): The full prompt and the expected needle value
"""
# The needle sentence
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
# Question at the end
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
# Estimate tokens for fixed parts
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
# Buffer for chat template, special tokens, etc.
overhead_tokens = 100 if use_chat_template else 50
# Available tokens for haystack
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
if haystack_target_tokens < 100:
raise ValueError(f"target_length {target_length} is too short for needle test")
# Build haystack by repeating paragraphs
haystack_parts = []
current_tokens = 0
para_idx = 0
while current_tokens < haystack_target_tokens:
para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
if current_tokens + para_tokens > haystack_target_tokens:
break
haystack_parts.append(para)
current_tokens += para_tokens
para_idx += 1
# Calculate needle insertion point
needle_idx = int(len(haystack_parts) * needle_position)
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
# Insert needle
haystack_parts.insert(needle_idx, needle)
# Assemble prompt
full_text = "".join(haystack_parts)
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
# Use chat template for instruct models
# For Qwen3, add /no_think to disable thinking mode
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
messages = [
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
else:
# Raw text format for base models
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
prompt = full_text + question
# Verify length
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
if verbose:
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
return prompt, needle_value
def check_needle_answer(output_text: str, expected: str) -> bool:
"""Check if the model output contains the expected needle value."""
# Clean output - remove special tokens and whitespace
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
output_clean = ' '.join(output_clean.split()).lower()
expected_clean = expected.strip().lower()
# Check if expected value appears in output
# Also try to find it as a standalone number
if expected_clean in output_clean:
return True
# Try to extract numbers and check if expected is among them
numbers = re.findall(r'\d+', output_clean)
return expected_clean in numbers
def generate_random_token_ids(
length: int,
vocab_size: int = 10000,
seed: int = 42,
) -> list:
"""
Generate random token IDs for testing.
Args:
length: Number of tokens to generate
vocab_size: Maximum token ID (exclusive)
seed: Random seed for reproducibility
Returns:
List of random token IDs
"""
from random import randint, seed as set_seed
set_seed(seed)
return [randint(0, vocab_size - 1) for _ in range(length)]