Compare commits

..

11 Commits

| Author | SHA1 | Message | Date |
|--------|------------|---------|------|
| Zijie Tian | 247c5312d9 | [fix] Fixed decode misalign. | 2026-01-05 19:00:44 +08:00 |
| Zijie Tian | 054aaff403 | [fix] Fixed needle test bug. | 2026-01-05 18:34:09 +08:00 |
| Zijie Tian | d623043a3c | [WIP] FIXED decode and prefill NEEDLE test. | 2026-01-05 01:51:46 +08:00 |
| Zijie Tian | e897380127 | [test] Added test_align.py and Before change nanovllm attention. | 2026-01-04 22:48:01 +08:00 |
| Zijie Tian | 24096431ed | [refactor] refactor test_align.py. | 2026-01-04 20:55:40 +08:00 |
| Zijie Tian | 772313db8f | [refactor] Refactor the kvcache offload. | 2026-01-04 19:37:03 +08:00 |
| Zijie Tian | 00ed17c640 | [feat] Added debug tools. | 2026-01-03 22:36:40 +08:00 |
| Zijie Tian | 9b52d25866 | [docs] Update CLAUDE.md. | 2026-01-03 20:46:00 +08:00 |
| Zijie Tian | 8c3418725b | [refactor] Refactor needle test. | 2026-01-03 19:19:37 +08:00 |
| Zijie Tian | b3685c9190 | [test] Added test_align.py | 2026-01-03 18:55:58 +08:00 |
| Zijie Tian | 6927a75ac3 | [refactor] refactor needle.py. | 2026-01-03 18:33:48 +08:00 |
24 changed files with 2793 additions and 361 deletions

View File

@@ -20,6 +20,80 @@ For sparse attention related content (block sparse attention, MInference, FlexPr
- **BlockManager** (`block_manager.py`): Paged attention with prefix caching (xxhash), default block size 4096
- **Attention** (`layers/attention.py`): FlashAttention with chunked methods for CPU offload
## PyTorch Hooks for Debugging
### Hook Positions in Qwen3
```
decoder_layer
├── input_layernorm (RMSNorm)
├── self_attn (Qwen3Attention) ← Hook here for attention I/O after o_proj
│ ├── q_proj → q_norm → RoPE
│ ├── k_proj → k_norm → RoPE
│ ├── v_proj
│ ├── attn (Attention) ← Hook here for Q/K/V tensors
│ │ └── FlashAttention / SDPA
│ └── o_proj
├── post_attention_layernorm (RMSNorm)
└── mlp (Qwen3MLP)
```
### Hook Types & Data Shapes
| Hook Position | Type | Captured Data |
|---------------|------|---------------|
| `self_attn` | post | `[batch, seq_len, hidden_size]` - after o_proj |
| `self_attn.attn` | pre | Q,K,V: `[seq_len, num_heads, head_dim]` - after RoPE |
| `self_attn.attn` | post | `[seq_len, num_heads, head_dim]` - before o_proj |
### Example: Capture Attention Outputs
```python
storage = {}
def make_hook(layer_id: int, storage: dict):
def hook(module, inputs, output):
if isinstance(output, tuple):
attn_output = output[0]
else:
attn_output = output
# nanovllm shape: [num_tokens, hidden_size] -> add batch dim
if attn_output.dim() == 2:
attn_output = attn_output.unsqueeze(0)
storage[layer_id] = attn_output.detach().clone()
return hook
# Register hooks
hooks = []
for layer_idx, layer in enumerate(model.model.layers):
hooks.append(layer.self_attn.register_forward_hook(make_hook(layer_idx, storage)))
# Run inference...
# Cleanup
for hook in hooks:
hook.remove()
```
### Alignment Testing
Use `tests/test_align.py` to compare nanovllm with the reference torch implementation:
```bash
python tests/test_align.py
```
Key files:
- `tests/modeling_qwen3.py`: Reference Qwen3 implementation (torch + transformers only)
- `tests/test_align.py`: Compares attention outputs between nanovllm and reference
- `tests/test_needle_ref.py`: Reference needle test using custom Qwen3
### Common Pitfalls
1. **Shape mismatch**: nanovllm uses `[num_tokens, ...]` while torch uses `[batch, seq_len, ...]`
2. **Hook position**: `self_attn` captures the output after o_proj, `self_attn.attn` captures it before o_proj (see the pre-hook sketch below)
3. **Output format**: nanovllm returns a tuple `(attn_output, None)`; handle it with `output[0]`
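For pitfall 2, Q/K/V can be captured with a *pre*-hook on `self_attn.attn` (see the hook table above). A minimal sketch, assuming the module layout shown earlier and the same `model` handle as the example:
```python
qkv_storage = {}

def make_qkv_hook(layer_id: int, storage: dict):
    def pre_hook(module, args):
        # Positional inputs of self_attn.attn are (q, k, v) after RoPE,
        # each [num_tokens, num_heads, head_dim] (no batch dim in nanovllm).
        q, k, v = args[:3]
        storage[layer_id] = tuple(t.detach().clone() for t in (q, k, v))
    return pre_hook

qkv_hooks = [
    layer.self_attn.attn.register_forward_pre_hook(make_qkv_hook(i, qkv_storage))
    for i, layer in enumerate(model.model.layers)
]
# Run inference, then remove hooks as in the example above.
```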
## CPU Offload System
### Ring Buffer Design
@@ -228,7 +302,7 @@ def _merge_output_kernel(...):
| Parameter | Default | Notes |
|-----------|---------|-------|
| `kvcache_block_size` | 4096 | Tokens per block |
| `kvcache_block_size` | 1024 | Tokens per block |
| `max_num_batched_tokens` | 16384 | Set = max_model_len for long context |
| `gpu_memory_utilization` | 0.9 | GPU memory fraction |
| `enable_cpu_offload` | False | Enable for long context |
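For illustration, a sketch of a long-context setup built from these parameters. The constructor and keyword names are assumed to mirror the `Config` fields above and may differ from the actual API; the model path is a placeholder:
```python
from nanovllm import LLM, SamplingParams  # assumed import path

llm = LLM(
    "/path/to/Qwen3-8B",                # placeholder model path
    kvcache_block_size=1024,
    max_num_batched_tokens=16384,       # set equal to max_model_len for long context
    gpu_memory_utilization=0.9,
    enable_cpu_offload=True,
)
outputs = llm.generate(["<long prompt> ..."], SamplingParams(temperature=0.0, max_tokens=32))
```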

DEBUG_SUMMARY.md (new file, 103 lines)
View File

@@ -0,0 +1,103 @@
# Chunked Prefill Bug Debug Summary
## Problem
`test_needle.py --enable-offload --input-len 8192` fails with garbage output.
The model generates completely wrong tokens instead of the expected "7492".
## Investigation Progress
### 1. Stream Synchronization Fix (Completed)
- Replaced Triton `store_kvcache` kernel with pure PyTorch operations
- Moved `store_kvcache` to `compute_stream` in chunked prefill mode
- Added sync: `compute_stream.wait_event(offload_done)` after per-layer offload
- Added sync: `default_stream.wait_stream(compute_stream)` before return
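A self-contained sketch of this synchronization pattern with dummy tensors; the names are placeholders, and the real logic lives in `nanovllm/layers/attention.py` and the offload engine:
```python
import torch

compute_stream = torch.cuda.Stream()
offload_stream = torch.cuda.Stream()
offload_done = torch.cuda.Event()

gpu_slot = torch.randn(4096, 8, 128, device="cuda", dtype=torch.float16)
cpu_block = torch.empty_like(gpu_slot, device="cpu").pin_memory()

with torch.cuda.stream(compute_stream):
    gpu_slot.mul_(1.0)  # stands in for store_kvcache writing the GPU slot

offload_stream.wait_stream(compute_stream)       # offload must see the finished write
with torch.cuda.stream(offload_stream):
    cpu_block.copy_(gpu_slot, non_blocking=True)
    offload_done.record(offload_stream)

compute_stream.wait_event(offload_done)           # next layer must not overwrite the slot mid-copy
torch.cuda.default_stream().wait_stream(compute_stream)  # result visible to layernorm/MLP
```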
### 2. KV Cache Alignment Verification (Completed)
Created alignment tests to compare K/V tensors between torch reference and nanovllm:
**RoPE Alignment:**
- RoPE implementations match to within FP16 tolerance (max_diff = 0.002, cosine ≈ 1.0)
- Confirmed RoPE is NOT the cause of the bug
**K/V Cache Alignment (Chunk 0):**
- Cosine similarity: ~1.0 for all layers
- Max diff: 2-7 (grows linearly with position, characteristic of FP16 precision)
- Mean diff: < 0.001
- **Conclusion: K/V cache offload is working correctly**
### 3. Layer Output Divergence Analysis (Completed)
Created per-chunk layer output comparison:
**Chunk 0 (tokens 0-4096):**
- All layers pass with excellent cosine similarity (0.999+)
- Max diff grows in later layers but within acceptable range
**Chunk 1 (tokens 4096-8192):**
- Layers 0-19: OK (cosine ~1.0)
- Layers 20-27: Diverge (cosine 0.83-0.96, max_diff up to 114)
- Divergence correlates with later transformer layers
### 4. Critical Discovery: Single-Chunk Offload Also Fails
**Key finding:** Even with input_len=2048 (single chunk, no chunked attention), the model produces garbage output with CPU offload enabled.
```
# Without offload: PASSES
python tests/test_needle.py --input-len 2048
# Output: "7492" (correct)
# With offload: FAILS
python tests/test_needle.py --enable-offload --input-len 2048
# Output: "The Ble White Th G Lopsiswin..." (garbage)
```
**This proves the bug is NOT in:**
- Chunked attention logic (merge_attention_outputs)
- Multi-chunk KV loading
- Ring buffer pipeline
**The bug IS in:**
- The decode path when CPU offload is enabled
- How prefilled KV is loaded/used during decode
### 5. Decode Path Analysis (In Progress)
The decode path in CPU offload mode:
1. Prefill writes KV to GPU, offloads to CPU
2. Decode loads prefilled KV from CPU via `_decode_ring_buffer_pipeline`
3. Attend to prefilled KV + accumulated decode tokens
4. Merge results (see the LSE-merge sketch below)
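Step 4 is the standard log-sum-exp merge of partial attention results. A sketch of the math, assuming `o` is `[batch, seq, heads, dim]` and `lse` is `[batch, heads, seq]` as returned by FlashAttention (this is not the repo's `merge_attention_outputs` itself):
```python
import torch

def merge_attn_sketch(o1: torch.Tensor, lse1: torch.Tensor,
                      o2: torch.Tensor, lse2: torch.Tensor):
    """Merge two partial attention outputs computed over disjoint KV sets."""
    lse = torch.logaddexp(lse1, lse2)                         # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # [batch, seq, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse
```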
**Observations:**
- `prefilled_blocks` set is empty after decode (should contain block IDs)
- CPU cache has valid data (reasonable mean/std values)
- Decode buffer has zeros (decode tokens not being stored correctly?)
## Current Status
### Working
- Stream synchronization fixes
- K/V cache offload to CPU (verified alignment)
- RoPE implementation
- Chunked prefill attention for first chunk
### Not Working
- Decode with CPU offload (even for single-chunk inputs)
- Multi-chunk attention (divergence in later layers for chunk 1)
## Next Steps
1. Debug why `prefilled_blocks` is empty after decode
2. Check if decode path correctly loads KV from CPU
3. Verify decode buffer is being written correctly
4. Compare decode attention outputs between offload and non-offload modes (see the sketch below)
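For step 4, a sketch that reuses the CLAUDE.md forward hooks and the `TensorComparator` added in this change; how the two runs (offload disabled vs. enabled) are driven is left abstract:
```python
from nanovllm.debug import TensorComparator

def compare_decode_outputs(ref_storage: dict, offload_storage: dict) -> None:
    """ref_storage / offload_storage: {layer_id: attn_output} captured via hooks."""
    comparator = TensorComparator()
    for layer_id in sorted(ref_storage):
        result = comparator.compare(
            ref_storage[layer_id], offload_storage[layer_id], name=f"layer_{layer_id}"
        )
        print(result)  # e.g. "✓ cos=0.999999, max_diff=1.20e-03"
```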
## Key Files
- `nanovllm/layers/attention.py` - Main attention implementation with chunked paths
- `nanovllm/kvcache/offload_engine.py` - CPU-GPU transfer engine
- `nanovllm/kvcache/hybrid_manager.py` - KV cache management with `prefilled_blocks`
- `nanovllm/engine/model_runner.py` - Prefill/decode orchestration
## Hypothesis
The decode path fails because one (or more) of the following holds:
1. `prefilled_blocks` is not being tracked correctly, causing `get_prefilled_cpu_blocks()` to return empty
2. OR the decode attention is not correctly loading/using the prefilled KV from CPU
3. OR there's a stream synchronization issue specific to decode path

View File

@@ -15,7 +15,7 @@ class Config:
enforce_eager: bool = False
hf_config: AutoConfig | None = None
eos: int = -1
kvcache_block_size: int = 4096
kvcache_block_size: int = 1024
num_kvcache_blocks: int = -1
dtype: str | None = None # "float16", "bfloat16", or None (use model default)

View File

@@ -0,0 +1,49 @@
"""
Breakpoint debugging tools for aligning nanovllm with reference implementations.
This module provides a generator-based breakpoint aligner that enables step-by-step
comparison between nanovllm and torch reference model outputs.
Example:
>>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
>>> from tests.modeling_qwen3 import Qwen3ForCausalLM
>>>
>>> # Load models
>>> torch_model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch.float16)
>>> nanovllm_model = ... # Your nanovllm model
>>>
>>> # Create adapters
>>> ref = TorchSteppable(torch_model)
>>> test = NanovllmSteppable(nanovllm_model)
>>>
>>> # Run alignment
>>> aligner = BreakpointAligner(ref, test)
>>> result = aligner.align(input_ids)
>>> print(result)
"""
from .breakpoints import BreakpointType, Breakpoint
from .comparator import TensorComparator, ComparisonResult
from .aligner import BreakpointAligner, AlignmentResult
from .adapters import SteppableModel, TorchSteppable, NanovllmSteppable
from .utils import setup_prefill_context, setup_decode_context, cleanup_context
__all__ = [
# Core classes
"BreakpointAligner",
"AlignmentResult",
# Breakpoints
"BreakpointType",
"Breakpoint",
# Comparator
"TensorComparator",
"ComparisonResult",
# Adapters
"SteppableModel",
"TorchSteppable",
"NanovllmSteppable",
# Utils
"setup_prefill_context",
"setup_decode_context",
"cleanup_context",
]

View File

@@ -0,0 +1,11 @@
"""Model adapters for breakpoint alignment."""
from .base import SteppableModel
from .torch_adapter import TorchSteppable
from .nanovllm_adapter import NanovllmSteppable
__all__ = [
"SteppableModel",
"TorchSteppable",
"NanovllmSteppable",
]

View File

@@ -0,0 +1,59 @@
"""Base class for steppable model adapters."""
from abc import ABC, abstractmethod
from typing import Generator, Set, Optional
import torch
from ..breakpoints import Breakpoint, BreakpointType
class SteppableModel(ABC):
"""
Abstract base class for models that can yield at breakpoints.
Subclasses implement the step() method as a generator that yields
Breakpoint objects at each enabled breakpoint during forward pass.
"""
def __init__(self, enabled_breakpoints: Optional[Set[BreakpointType]] = None):
"""
Args:
enabled_breakpoints: Set of breakpoint types to yield at.
If None, yields at all breakpoints.
"""
self.enabled_breakpoints = enabled_breakpoints
def is_enabled(self, bp_type: BreakpointType) -> bool:
"""Check if a breakpoint type is enabled."""
if self.enabled_breakpoints is None:
return True
return bp_type in self.enabled_breakpoints
@abstractmethod
def step(
self,
input_ids: torch.Tensor,
positions: Optional[torch.Tensor] = None,
is_prefill: bool = True,
) -> Generator[Breakpoint, None, torch.Tensor]:
"""
Generator that yields Breakpoint objects at enabled breakpoints.
Args:
input_ids: Input token IDs
positions: Position IDs (optional, auto-generated if None)
is_prefill: True for prefill phase, False for decode
Yields:
Breakpoint objects at each enabled checkpoint
Returns:
Final output tensor (logits)
"""
pass
@property
@abstractmethod
def num_layers(self) -> int:
"""Return the number of decoder layers."""
pass

View File

@@ -0,0 +1,235 @@
"""Nanovllm model adapter for breakpoint alignment."""
from typing import Generator, Set, Optional, Dict, Any, List
import torch
from nanovllm.utils.context import set_context, reset_context
from ..breakpoints import Breakpoint, BreakpointType
from .base import SteppableModel
class NanovllmSteppable(SteppableModel):
"""
Steppable adapter for nanovllm Qwen3 implementation.
Uses PyTorch hooks to capture intermediate values during forward pass,
then yields them as breakpoints after execution completes.
Key challenges handled:
1. Shape difference: nanovllm uses [num_tokens, hidden] vs [batch, seq, hidden]
2. Context-based attention: must call set_context() before forward
3. Fused operations: decoder layer returns (hidden_states, residual) tuple
"""
def __init__(
self,
model,
enabled_breakpoints: Optional[Set[BreakpointType]] = None,
):
"""
Args:
model: Qwen3ForCausalLM from nanovllm
enabled_breakpoints: Set of breakpoint types to yield at
"""
super().__init__(enabled_breakpoints)
self.model = model
self.model.eval()
self._hooks: List[Any] = []
self._captured: Dict[str, torch.Tensor] = {}
@property
def num_layers(self) -> int:
return len(self.model.model.layers)
def _register_hooks(self):
"""Register forward hooks on all relevant modules."""
self._hooks = []
self._captured = {}
# Hook for embedding output
def embed_hook(module, input, output):
self._captured["embed"] = output.detach().clone()
self._hooks.append(
self.model.model.embed_tokens.register_forward_hook(embed_hook)
)
# Hooks for each decoder layer
for layer_idx in range(self.num_layers):
layer = self.model.model.layers[layer_idx]
def make_layer_hook(idx):
def hook(module, input, output):
# Decoder layer returns (hidden_states, residual)
# hidden_states is MLP output, residual is accumulated residual
# To match torch reference, we need hidden_states + residual
if isinstance(output, tuple) and len(output) >= 2:
hidden_states, residual = output[0], output[1]
full_output = hidden_states + residual
else:
full_output = output
self._captured[f"layer_{idx}"] = full_output.detach().clone()
return hook
self._hooks.append(
layer.register_forward_hook(make_layer_hook(layer_idx))
)
# Hook for final norm
def final_norm_hook(module, input, output):
# Final norm returns (hidden_states, _) for fused add
hidden_states = output[0] if isinstance(output, tuple) else output
self._captured["final_norm"] = hidden_states.detach().clone()
self._hooks.append(
self.model.model.norm.register_forward_hook(final_norm_hook)
)
# Hook for lm_head
def lm_head_hook(module, input, output):
self._captured["lm_head"] = output.detach().clone()
self._hooks.append(
self.model.lm_head.register_forward_hook(lm_head_hook)
)
def _remove_hooks(self):
"""Remove all registered hooks."""
for hook in self._hooks:
hook.remove()
self._hooks = []
def _setup_context(self, seq_len: int, device: torch.device, is_prefill: bool):
"""
Set up nanovllm context for forward pass.
For alignment testing, we use simple context without real KV cache.
"""
if is_prefill:
# Prefill: process all tokens at once
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
# Use -1 for slot_mapping to skip KV cache writes
slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)
set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
slot_mapping=slot_mapping,
is_chunked_prefill=False,
)
else:
# Decode: single token generation
# For decode, we need context_lens and block_tables
# For alignment testing without real KV cache, we use minimal setup
context_lens = torch.tensor([seq_len - 1], dtype=torch.int32, device=device)
# Single token slot
slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
# Empty block tables (no KV cache)
block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)
set_context(
is_prefill=False,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
)
def _normalize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
"""
Normalize nanovllm tensor shape to [batch, seq_len, ...].
nanovllm uses [num_tokens, ...] format without batch dimension.
We add batch dimension for comparison with torch model.
"""
if tensor.dim() == 2: # [num_tokens, hidden_size]
return tensor.unsqueeze(0)
return tensor
def step(
self,
input_ids: torch.Tensor,
positions: Optional[torch.Tensor] = None,
is_prefill: bool = True,
) -> Generator[Breakpoint, None, torch.Tensor]:
"""
Execute nanovllm forward pass with hooks to capture breakpoints.
Unlike the torch adapter which manually steps through each component,
we run the full forward pass and collect captured values afterward.
"""
# Ensure 1D for nanovllm (it expects [num_tokens])
if input_ids.dim() == 2:
input_ids = input_ids.squeeze(0)
seq_len = input_ids.numel()
device = input_ids.device
# Generate position IDs if not provided
if positions is None:
positions = torch.arange(seq_len, device=device)
elif positions.dim() == 2:
positions = positions.squeeze(0)
# Register hooks
self._register_hooks()
try:
# Setup context for attention
self._setup_context(seq_len, device, is_prefill)
# Run forward pass (hooks capture everything)
with torch.no_grad():
hidden_states = self.model(input_ids, positions)
logits = self.model.compute_logits(hidden_states)
reset_context()
# Yield breakpoints in order from captured data
# EMBEDDING
if self.is_enabled(BreakpointType.EMBEDDING) and "embed" in self._captured:
yield Breakpoint(
bp_type=BreakpointType.EMBEDDING,
layer_idx=None,
tensor=self._normalize_tensor(self._captured["embed"]),
name="Embedding",
)
# LAYER_OUTPUT for each layer
if self.is_enabled(BreakpointType.LAYER_OUTPUT):
for layer_idx in range(self.num_layers):
key = f"layer_{layer_idx}"
if key in self._captured:
yield Breakpoint(
bp_type=BreakpointType.LAYER_OUTPUT,
layer_idx=layer_idx,
tensor=self._normalize_tensor(self._captured[key]),
name=f"Layer {layer_idx}",
)
# FINAL_NORM
if self.is_enabled(BreakpointType.FINAL_NORM) and "final_norm" in self._captured:
yield Breakpoint(
bp_type=BreakpointType.FINAL_NORM,
layer_idx=None,
tensor=self._normalize_tensor(self._captured["final_norm"]),
name="Final Norm",
)
# LM_HEAD
if self.is_enabled(BreakpointType.LM_HEAD) and "lm_head" in self._captured:
yield Breakpoint(
bp_type=BreakpointType.LM_HEAD,
layer_idx=None,
tensor=self._normalize_tensor(self._captured["lm_head"]),
name="LM Head",
)
return logits
finally:
self._remove_hooks()
self._captured = {}

View File

@@ -0,0 +1,119 @@
"""Torch reference model adapter for breakpoint alignment."""
from typing import Generator, Set, Optional
import torch
from ..breakpoints import Breakpoint, BreakpointType
from .base import SteppableModel
class TorchSteppable(SteppableModel):
"""
Steppable adapter for the torch reference Qwen3 implementation.
Wraps tests/modeling_qwen3.py Qwen3ForCausalLM model.
"""
def __init__(
self,
model,
enabled_breakpoints: Optional[Set[BreakpointType]] = None,
):
"""
Args:
model: Qwen3ForCausalLM from tests/modeling_qwen3.py
enabled_breakpoints: Set of breakpoint types to yield at
"""
super().__init__(enabled_breakpoints)
self.model = model
self.model.eval()
@property
def num_layers(self) -> int:
return len(self.model.model.layers)
def step(
self,
input_ids: torch.Tensor,
positions: Optional[torch.Tensor] = None,
is_prefill: bool = True,
) -> Generator[Breakpoint, None, torch.Tensor]:
"""
Generator that manually steps through the torch model.
The torch model uses [batch, seq_len, hidden_size] shapes.
"""
# Ensure batch dimension
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Generate position IDs if not provided
if positions is None:
positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
elif positions.dim() == 1:
positions = positions.unsqueeze(0)
with torch.no_grad():
# === EMBEDDING ===
hidden_states = self.model.model.embed_tokens(input_ids)
if self.is_enabled(BreakpointType.EMBEDDING):
yield Breakpoint(
bp_type=BreakpointType.EMBEDDING,
layer_idx=None,
tensor=hidden_states.detach().clone(),
name="Embedding",
)
# Create causal attention mask
causal_mask = torch.triu(
torch.full((seq_len, seq_len), float("-inf"), device=device),
diagonal=1,
)
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0)
# === DECODER LAYERS ===
for layer_idx, layer in enumerate(self.model.model.layers):
hidden_states, _, _ = layer(
hidden_states=hidden_states,
position_ids=positions,
attention_mask=attention_mask,
past_key_value=None,
use_cache=False,
output_qkv=False,
)
if self.is_enabled(BreakpointType.LAYER_OUTPUT):
yield Breakpoint(
bp_type=BreakpointType.LAYER_OUTPUT,
layer_idx=layer_idx,
tensor=hidden_states.detach().clone(),
name=f"Layer {layer_idx}",
)
# === FINAL NORM ===
hidden_states = self.model.model.norm(hidden_states)
if self.is_enabled(BreakpointType.FINAL_NORM):
yield Breakpoint(
bp_type=BreakpointType.FINAL_NORM,
layer_idx=None,
tensor=hidden_states.detach().clone(),
name="Final Norm",
)
# === LM HEAD ===
logits = self.model.lm_head(hidden_states)
if self.is_enabled(BreakpointType.LM_HEAD):
yield Breakpoint(
bp_type=BreakpointType.LM_HEAD,
layer_idx=None,
tensor=logits.detach().clone(),
name="LM Head",
)
return logits

nanovllm/debug/aligner.py (new file, 211 lines)
View File

@@ -0,0 +1,211 @@
"""Breakpoint aligner for comparing model outputs."""
from dataclasses import dataclass, field
from typing import List, Tuple, Optional
import torch
from .breakpoints import Breakpoint
from .comparator import TensorComparator, ComparisonResult
from .adapters.base import SteppableModel
@dataclass
class AlignmentResult:
"""Result of an alignment test."""
passed: bool
all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = field(default_factory=list)
failed_at: Optional[Breakpoint] = None
message: str = ""
def __repr__(self) -> str:
passed_count = sum(1 for _, _, c in self.all_comparisons if c.passed)
total = len(self.all_comparisons)
status = "PASSED" if self.passed else "FAILED"
return f"AlignmentResult({status}, {passed_count}/{total} breakpoints passed)"
class BreakpointAligner:
"""
Orchestrates alternating execution of reference and test models,
comparing outputs at each breakpoint.
Example:
>>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
>>> ref = TorchSteppable(torch_model)
>>> test = NanovllmSteppable(nanovllm_model)
>>> aligner = BreakpointAligner(ref, test)
>>> result = aligner.align(input_ids)
>>> print(result)
"""
def __init__(
self,
ref_model: SteppableModel,
test_model: SteppableModel,
comparator: Optional[TensorComparator] = None,
stop_on_error: bool = True,
verbose: bool = True,
):
"""
Args:
ref_model: Reference (torch) steppable model
test_model: Test (nanovllm) steppable model
comparator: Tensor comparator instance (uses default if None)
stop_on_error: If True, stop at first mismatch
verbose: If True, print comparison results
"""
self.ref_model = ref_model
self.test_model = test_model
self.comparator = comparator or TensorComparator()
self.stop_on_error = stop_on_error
self.verbose = verbose
def align(
self,
input_ids: torch.Tensor,
positions: Optional[torch.Tensor] = None,
is_prefill: bool = True,
) -> AlignmentResult:
"""
Run both models with same input, comparing at each breakpoint.
Args:
input_ids: Input token IDs
positions: Position IDs (optional)
is_prefill: True for prefill phase, False for decode
Returns:
AlignmentResult with pass/fail status and details
"""
all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = []
if self.verbose:
phase = "prefill" if is_prefill else "decode"
print(f"\n{'='*60}")
print(f"Alignment Test ({phase})")
print(f"{'='*60}")
# Start both generators
ref_gen = self.ref_model.step(input_ids, positions, is_prefill)
test_gen = self.test_model.step(input_ids, positions, is_prefill)
try:
while True:
# Get next breakpoint from reference
try:
ref_bp = next(ref_gen)
except StopIteration:
break
# Get corresponding breakpoint from test
try:
test_bp = next(test_gen)
except StopIteration:
if self.verbose:
print(f"Test model ended early at {ref_bp.name}")
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=f"Test model ended early at {ref_bp.name}",
)
# Verify breakpoints match
if ref_bp.bp_type != test_bp.bp_type:
msg = f"Breakpoint type mismatch: {ref_bp.bp_type} vs {test_bp.bp_type}"
if self.verbose:
print(msg)
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=msg,
)
if ref_bp.layer_idx != test_bp.layer_idx:
msg = f"Layer index mismatch: {ref_bp.layer_idx} vs {test_bp.layer_idx}"
if self.verbose:
print(msg)
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=msg,
)
# Normalize shapes for comparison
ref_t = ref_bp.normalize_shape()
test_t = test_bp.normalize_shape()
# Handle shape mismatches
if ref_t.shape != test_t.shape:
if self.verbose:
print(f"[{ref_bp.name}] Shape mismatch: ref={ref_t.shape} vs test={test_t.shape}")
# Try to reshape if element count matches
if ref_t.numel() == test_t.numel():
test_t = test_t.view(ref_t.shape)
else:
msg = f"Shape mismatch at {ref_bp.name}: {ref_t.shape} vs {test_t.shape}"
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=msg,
)
# Compare tensors
result = self.comparator.compare(ref_t, test_t, ref_bp.name)
all_comparisons.append((ref_bp, test_bp, result))
if self.verbose:
status = "\u2713" if result.passed else "\u2717"
print(f"{status} [{ref_bp.name}] cos={result.cosine_similarity:.6f}, max_diff={result.max_abs_diff:.2e}")
if not result.passed and self.stop_on_error:
if self.verbose:
print(f"\nStopped at {ref_bp.name} (stop_on_error=True)")
print(result.message)
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
failed_at=ref_bp,
message=f"Alignment failed at {ref_bp.name}",
)
# Check for extra test breakpoints
try:
extra_bp = next(test_gen)
msg = f"Test model has extra breakpoints starting at {extra_bp.name}"
if self.verbose:
print(msg)
return AlignmentResult(
passed=False,
all_comparisons=all_comparisons,
message=msg,
)
except StopIteration:
pass
except Exception as e:
msg = f"Exception during alignment: {str(e)}"
if self.verbose:
print(msg)
raise
# Summary
all_passed = all(comp[2].passed for comp in all_comparisons)
passed_count = sum(1 for _, _, c in all_comparisons if c.passed)
total = len(all_comparisons)
if self.verbose:
print(f"{'='*60}")
status = "PASSED" if all_passed else "FAILED"
print(f"Result: {status} ({passed_count}/{total} breakpoints)")
print(f"{'='*60}\n")
return AlignmentResult(
passed=all_passed,
all_comparisons=all_comparisons,
message="All breakpoints aligned" if all_passed else "Some breakpoints failed",
)

View File

@@ -0,0 +1,39 @@
"""Breakpoint types and data structures for alignment debugging."""
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional
import torch
class BreakpointType(Enum):
"""Types of breakpoints in the model forward pass."""
EMBEDDING = auto() # After embed_tokens
LAYER_OUTPUT = auto() # After each decoder layer
FINAL_NORM = auto() # After final RMSNorm
LM_HEAD = auto() # After lm_head (logits)
@dataclass
class Breakpoint:
"""A captured breakpoint with tensor data."""
bp_type: BreakpointType
layer_idx: Optional[int] # None for EMBEDDING, FINAL_NORM, LM_HEAD
tensor: torch.Tensor
name: str
def normalize_shape(self) -> torch.Tensor:
"""
Normalize tensor shape for comparison.
nanovllm uses [num_tokens, hidden_size] while torch uses
[batch, seq_len, hidden_size]. This adds a batch dimension
to 2D tensors for comparison.
"""
if self.tensor.dim() == 2:
return self.tensor.unsqueeze(0)
return self.tensor
def __repr__(self) -> str:
shape_str = "x".join(str(d) for d in self.tensor.shape)
return f"Breakpoint({self.name}, shape={shape_str}, dtype={self.tensor.dtype})"

View File

@@ -0,0 +1,94 @@
"""Tensor comparison utilities for alignment debugging."""
from dataclasses import dataclass
import torch
import torch.nn.functional as F
@dataclass
class ComparisonResult:
"""Result of comparing two tensors."""
passed: bool
cosine_similarity: float
max_abs_diff: float
mean_abs_diff: float
message: str
def __repr__(self) -> str:
status = "\u2713" if self.passed else "\u2717"
return f"{status} cos={self.cosine_similarity:.6f}, max_diff={self.max_abs_diff:.2e}"
class TensorComparator:
"""Compares tensors using cosine similarity and absolute differences."""
def __init__(
self,
cosine_threshold: float = 0.999,
max_diff_threshold: float = 0.1,
mean_diff_threshold: float = 0.01,
):
"""
Args:
cosine_threshold: Minimum cosine similarity to pass (0-1)
max_diff_threshold: Maximum allowed absolute difference
mean_diff_threshold: Maximum allowed mean absolute difference
"""
self.cosine_threshold = cosine_threshold
self.max_diff_threshold = max_diff_threshold
self.mean_diff_threshold = mean_diff_threshold
def compare(
self,
ref: torch.Tensor,
test: torch.Tensor,
name: str = "",
) -> ComparisonResult:
"""
Compare two tensors and return detailed result.
Args:
ref: Reference tensor
test: Test tensor
name: Name for the comparison (used in message)
Returns:
ComparisonResult with pass/fail status and metrics
"""
# Convert to float32 for comparison
ref_f = ref.float().flatten()
test_f = test.float().flatten()
# Cosine similarity
cos_sim = F.cosine_similarity(
ref_f.unsqueeze(0),
test_f.unsqueeze(0)
).item()
# Absolute differences
diff = (ref.float() - test.float()).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
# Check thresholds
passed = (
cos_sim >= self.cosine_threshold and
max_diff <= self.max_diff_threshold and
mean_diff <= self.mean_diff_threshold
)
status = "PASS" if passed else "FAIL"
message = (
f"[{name}] {status}\n"
f" Cosine Similarity: {cos_sim:.6f} (threshold: {self.cosine_threshold})\n"
f" Max Abs Diff: {max_diff:.6f} (threshold: {self.max_diff_threshold})\n"
f" Mean Abs Diff: {mean_diff:.6f} (threshold: {self.mean_diff_threshold})"
)
return ComparisonResult(
passed=passed,
cosine_similarity=cos_sim,
max_abs_diff=max_diff,
mean_abs_diff=mean_diff,
message=message,
)

nanovllm/debug/utils.py (new file, 51 lines)
View File

@@ -0,0 +1,51 @@
"""Utility functions for breakpoint alignment debugging."""
import torch
from nanovllm.utils.context import set_context, reset_context
def setup_prefill_context(seq_len: int, device: torch.device):
"""
Set up nanovllm context for prefill alignment testing.
Args:
seq_len: Sequence length
device: Target device
"""
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)
set_context(
is_prefill=True,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
slot_mapping=slot_mapping,
is_chunked_prefill=False,
)
def setup_decode_context(context_len: int, device: torch.device):
"""
Set up nanovllm context for decode alignment testing.
Args:
context_len: Context length (number of previous tokens)
device: Target device
"""
context_lens = torch.tensor([context_len], dtype=torch.int32, device=device)
slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)
set_context(
is_prefill=False,
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
)
def cleanup_context():
"""Reset nanovllm context after alignment testing."""
reset_context()

View File

@@ -35,7 +35,10 @@ class ModelRunner:
self.model = Qwen3ForCausalLM(hf_config)
load_model(self.model, config.model)
self.sampler = GreedySampler()
self.warmup_model()
#> Disable warmup for debugging
# self.warmup_model()
self.allocate_kv_cache()
if not self.enforce_eager:
self.capture_cudagraph()
@@ -194,7 +197,7 @@ class ModelRunner:
f"block_size={self.block_size}"
)
# Bind layer caches to attention modules and set layer_id
#> Bind layer caches to attention modules and set layer_id
layer_id = 0
for module in self.model.modules():
if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
@@ -480,7 +483,7 @@ class ModelRunner:
if input_ids.numel() == 0:
break
# Run model forward
#> Run model forward
logits = self.run_model(input_ids, positions, is_prefill=True)
reset_context()

View File

@@ -12,7 +12,7 @@ class SequenceStatus(Enum):
class Sequence:
block_size = 4096
block_size = 1024
counter = count()
def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
@@ -34,6 +34,14 @@ class Sequence:
def __getitem__(self, key):
return self.token_ids[key]
def __repr__(self):
ids = self.token_ids
if len(ids) > 20:
ids_str = "[" + ", ".join(map(str, ids[:10])) + ", ..., " + ", ".join(map(str, ids[-5:])) + "]"
else:
ids_str = str(ids)
return f"Seq(id={self.seq_id}, status={self.status.name}, tokens={self.num_tokens}, ids={ids_str})"
@property
def is_finished(self):
return self.status == SequenceStatus.FINISHED

View File

@@ -146,6 +146,10 @@ class HybridKVCacheManager(KVCacheManager):
# Key: sequence id, Value: starting position where decode began in current block
self._decode_start_pos: Dict[int, int] = {}
# Track original prefill length (for correct last_block_valid_tokens calculation)
# Key: sequence id, Value: number of tokens from prefill (before decode started)
self._prefill_len: Dict[int, int] = {}
# Sparse attention policy (optional)
self.sparse_policy: Optional["SparsePolicy"] = None
@@ -542,6 +546,26 @@ class HybridKVCacheManager(KVCacheManager):
seq_id = id(seq)
self._decode_start_pos[seq_id] = 0
def get_prefill_len(self, seq: Sequence) -> int:
"""
Get the original prefill length for a sequence.
This is cached on first call to ensure correct last_block_valid_tokens
calculation during decode (the CPU blocks don't change after prefill).
Args:
seq: Sequence
Returns:
Number of tokens from prefill (before decode started)
"""
seq_id = id(seq)
if seq_id not in self._prefill_len:
# First decode step - store the prefill length
# len(seq) - 1 because current len includes the first decode token
self._prefill_len[seq_id] = len(seq) - 1
return self._prefill_len[seq_id]
def clear_decode_tracking(self, seq: Sequence) -> None:
"""
Clear decode position tracking for sequence.
@@ -553,6 +577,7 @@ class HybridKVCacheManager(KVCacheManager):
"""
seq_id = id(seq)
self._decode_start_pos.pop(seq_id, None)
self._prefill_len.pop(seq_id, None)
def __repr__(self) -> str:
return (

View File

@@ -39,7 +39,7 @@ class PolicyContext:
is_prefill: bool
"""True if in prefill phase, False if in decode phase."""
block_size: int = 4096
block_size: int = 1024
"""Number of tokens per block."""
total_kv_len: int = 0

View File

@@ -2,8 +2,6 @@ import logging
import torch
import torch.cuda.nvtx
from torch import nn
import triton
import triton.language as tl
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context
@@ -12,37 +10,49 @@ from nanovllm.kvcache.sparse.policy import PolicyContext
logger = logging.getLogger(__name__)
@triton.jit
def store_kvcache_kernel(
key_ptr,
key_stride,
value_ptr,
value_stride,
k_cache_ptr,
v_cache_ptr,
slot_mapping_ptr,
D: tl.constexpr,
def store_kvcache(
key: torch.Tensor,
value: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
slot_mapping: torch.Tensor,
):
idx = tl.program_id(0)
slot = tl.load(slot_mapping_ptr + idx)
if slot == -1: return
key_offsets = idx * key_stride + tl.arange(0, D)
value_offsets = idx * value_stride + tl.arange(0, D)
key = tl.load(key_ptr + key_offsets)
value = tl.load(value_ptr + value_offsets)
cache_offsets = slot * D + tl.arange(0, D)
tl.store(k_cache_ptr + cache_offsets, key)
tl.store(v_cache_ptr + cache_offsets, value)
"""
Store key/value tensors into KV cache using slot mapping.
This is a pure PyTorch implementation replacing the previous Triton kernel.
Uses index_copy_ for efficient in-place scatter operation.
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
N, num_heads, head_dim = key.shape
D = num_heads * head_dim
assert key.stride(-1) == 1 and value.stride(-1) == 1
assert key.stride(1) == head_dim and value.stride(1) == head_dim
assert k_cache.stride(1) == D and v_cache.stride(1) == D
assert slot_mapping.numel() == N
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
Args:
key: [N, num_kv_heads, head_dim]
value: [N, num_kv_heads, head_dim]
k_cache: [num_blocks, block_size, num_kv_heads, head_dim] or similar
v_cache: same shape as k_cache
slot_mapping: [N] with values as flat indices, -1 means skip
"""
# Filter out invalid slots (slot == -1)
valid_mask = slot_mapping >= 0
if not valid_mask.any():
return
valid_slots = slot_mapping[valid_mask]
valid_keys = key[valid_mask] # [M, num_kv_heads, head_dim]
valid_values = value[valid_mask]
# Flatten cache and KV for scatter operation
# Cache is viewed as [total_slots, D] where D = num_kv_heads * head_dim
N, num_kv_heads, head_dim = key.shape
D = num_kv_heads * head_dim
total_slots = k_cache.numel() // D
k_cache_flat = k_cache.view(total_slots, D)
v_cache_flat = v_cache.view(total_slots, D)
valid_keys_flat = valid_keys.reshape(-1, D)
valid_values_flat = valid_values.reshape(-1, D)
# In-place scatter using index_copy_
k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat)
v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat)
class Attention(nn.Module):
@@ -66,8 +76,30 @@ class Attention(nn.Module):
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache
if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
# Determine if we're in chunked offload mode
is_chunked_offload = (
context.is_chunked_prefill and
hasattr(context, 'kvcache_manager') and
context.kvcache_manager is not None and
hasattr(context.kvcache_manager, 'offload_engine')
)
#! Ensure synchronization before accessing k_cache/v_cache
torch.cuda.synchronize()
#! =======================================================
if is_chunked_offload:
# Chunked offload mode: use compute_stream for store_kvcache
# This ensures proper synchronization with per-layer offload
compute_stream = context.kvcache_manager.offload_engine.compute_stream
if k_cache.numel() and v_cache.numel():
with torch.cuda.stream(compute_stream):
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
else:
# Normal mode: store on default stream
if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
if context.is_chunked_prefill:
@@ -182,31 +214,48 @@ class Attention(nn.Module):
current_chunk_idx
)
# Get compute stream for all attention operations
compute_stream = None
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
compute_stream = kvcache_manager.offload_engine.compute_stream
# Compute attention against current chunk's KV (with causal mask)
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
# Use compute_stream to ensure proper sync with store_kvcache and offload
if compute_stream is not None:
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
else:
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
# Merge with accumulated
# Merge with accumulated (all on compute_stream for consistency)
if o_acc is None:
final_o = current_o
else:
# IMPORTANT: o_acc was computed on compute_stream. We need to sync before
# reading it on the default stream for the merge operation.
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
offload_engine = kvcache_manager.offload_engine
torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
if compute_stream is not None:
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
else:
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop() # ChunkedPrefill
@@ -222,6 +271,16 @@ class Attention(nn.Module):
cpu_block_id = cpu_block_ids[current_chunk_idx]
offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
# CRITICAL: compute_stream must wait for offload to complete
# before the next layer's store_kvcache can overwrite the GPU slot.
# Without this, Layer N+1's store races with Layer N's offload copy.
compute_stream.wait_event(offload_engine.ring_slot_offload_done[write_slot])
# Sync default stream with compute_stream before returning
# This ensures the result is ready for the rest of the model (layernorm, MLP)
if compute_stream is not None:
torch.cuda.default_stream().wait_stream(compute_stream)
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
@@ -318,6 +377,7 @@ class Attention(nn.Module):
offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
@@ -364,6 +424,7 @@ class Attention(nn.Module):
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
@@ -426,13 +487,15 @@ class Attention(nn.Module):
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
# Calculate valid tokens in the last block
# Note: For chunked prefill, each block is exactly block_size tokens
# The cpu_block_table only contains full prefill blocks
# Calculate valid tokens in the last CPU block
# CRITICAL: Use original prefill length, not current seq length!
# CPU blocks are fixed after prefill, their content doesn't change during decode.
block_size = kvcache_manager.block_size
num_prefill_blocks = len(cpu_block_table)
# All prefill blocks are full (block_size tokens each)
last_block_valid_tokens = block_size
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
last_block_valid_tokens = total_prefill_tokens % block_size
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
last_block_valid_tokens = block_size # Last block was exactly full
# Apply sparse policy if enabled
if kvcache_manager.sparse_policy is not None:

tests/modeling_qwen3.py (new file, 757 lines)
View File

@@ -0,0 +1,757 @@
"""
Custom Qwen3 implementation using only torch and transformers.
This file provides a clean reference implementation for understanding the model computation graph.
Computation Graph:
==================
Input: token_ids [batch, seq_len]
┌─────────────┐
│ Embedding │ embed_tokens: [vocab_size, hidden_size]
└─────────────┘
hidden_states [batch, seq_len, hidden_size]
┌─────────────────────────────────────────────────────────┐
│ Decoder Layer (x N) │
│ ┌───────────────────────────────────────────────────┐ │
│ │ Self Attention Block │ │
│ │ │ │
│ │ input_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3Attention │ │ │
│ │ │ Q = q_proj(x) → q_norm → reshape │ │ │
│ │ │ K = k_proj(x) → k_norm → reshape │ │ │
│ │ │ V = v_proj(x) → reshape │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ attn_output = attention(Q, K, V) │ │ │
│ │ │ │ │ │ │
│ │ │ ▼ │ │ │
│ │ │ output = o_proj(attn_output) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + attn_output │ │
│ └───────────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ MLP Block │ │
│ │ │ │
│ │ post_attention_layernorm (RMSNorm) │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ ┌─────────────────────────────────────────────┐ │ │
│ │ │ Qwen3MLP │ │ │
│ │ │ gate = gate_proj(x) │ │ │
│ │ │ up = up_proj(x) │ │ │
│ │ │ output = down_proj(silu(gate) * up) │ │ │
│ │ └─────────────────────────────────────────────┘ │ │
│ │ │ │ │
│ │ ▼ │ │
│ │ hidden_states = residual + mlp_output │ │
│ └───────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
┌─────────────┐
│ norm │ final RMSNorm
└─────────────┘
┌─────────────┐
│ lm_head │ [hidden_size, vocab_size]
└─────────────┘
logits [batch, seq_len, vocab_size]
"""
import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
class Qwen3RMSNorm(nn.Module):
"""RMSNorm implementation."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
input_dtype = x.dtype
x = x.float()
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
return self.weight * x.to(input_dtype)
class Qwen3RotaryEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE)."""
def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
# Compute inverse frequencies
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
@torch.no_grad()
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x: Input tensor [batch, seq_len, num_heads, head_dim] or similar
position_ids: Position indices [batch, seq_len]
Returns:
cos, sin: [batch, seq_len, head_dim]
"""
# inv_freq: [dim/2]
# position_ids: [batch, seq_len]
inv_freq_expanded = self.inv_freq[None, :, None].float() # [1, dim/2, 1]
position_ids_expanded = position_ids[:, None, :].float() # [batch, 1, seq_len]
# freqs: [batch, dim/2, seq_len]
freqs = inv_freq_expanded @ position_ids_expanded
# freqs: [batch, seq_len, dim/2]
freqs = freqs.transpose(1, 2)
# Duplicate for full head_dim: [batch, seq_len, dim]
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos().to(x.dtype)
sin = emb.sin().to(x.dtype)
return cos, sin
def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(
q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary position embeddings to Q and K.
Args:
q: [batch, num_heads, seq_len, head_dim]
k: [batch, num_kv_heads, seq_len, head_dim]
cos: [batch, seq_len, head_dim]
sin: [batch, seq_len, head_dim]
Returns:
q_embed, k_embed with same shapes as inputs
"""
# Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim]
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class Qwen3Attention(nn.Module):
"""
Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim]
├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim]
└──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim]
apply_rotary_pos_emb(Q, K)
attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim]
reshape ──► o_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(
self,
hidden_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
attention_bias: bool = False,
rms_norm_eps: float = 1e-6,
layer_idx: int = 0,
):
super().__init__()
self.hidden_size = hidden_size
self.num_heads = num_attention_heads
self.num_kv_heads = num_key_value_heads
self.head_dim = head_dim
self.num_kv_groups = num_attention_heads // num_key_value_heads
self.layer_idx = layer_idx
# Scaling factor
self.scaling = head_dim ** -0.5
# QKV projections
self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
# QK normalization (Qwen3 specific)
self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
# Rotary embeddings
self.rotary_emb = Qwen3RotaryEmbedding(
head_dim,
max_position_embeddings=max_position_embeddings,
base=rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask)
past_key_value: (k_cache, v_cache) from previous steps
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V tensors for debugging
Returns:
output: [batch, seq_len, hidden_size]
past_key_value: Updated cache (if use_cache=True)
qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
"""
batch_size, seq_len, _ = hidden_states.shape
# === QKV Projections ===
q = self.q_proj(hidden_states) # [batch, seq_len, num_heads * head_dim]
k = self.k_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
v = self.v_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
# Reshape to [batch, seq_len, num_heads, head_dim]
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
# === QK Normalization (Qwen3 specific) ===
q = self.q_norm(q)
k = self.k_norm(k)
# Transpose to [batch, num_heads, seq_len, head_dim]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# === Rotary Position Embeddings ===
cos, sin = self.rotary_emb(v, position_ids)
q, k = apply_rotary_pos_emb(q, k, cos, sin)
# === KV Cache Update ===
if past_key_value is not None:
k_cache, v_cache = past_key_value
k = torch.cat([k_cache, k], dim=2)
v = torch.cat([v_cache, v], dim=2)
new_past_key_value = (k, v) if use_cache else None
# === Grouped Query Attention (expand KV heads if needed) ===
if self.num_kv_groups > 1:
# Repeat KV for each query group
k = k.repeat_interleave(self.num_kv_groups, dim=1)
v = v.repeat_interleave(self.num_kv_groups, dim=1)
# === Attention Computation (using SDPA for memory efficiency) ===
# Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend
# is_causal only works when q_len == kv_len (prefill), not during decode
q_len, kv_len = q.shape[2], k.shape[2]
is_causal = (q_len == kv_len) and (q_len > 1)
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
dropout_p=0.0,
is_causal=is_causal,
scale=self.scaling,
) # [batch, num_heads, seq_len, head_dim]
# === Output Projection ===
# Transpose back and reshape
attn_output = attn_output.transpose(1, 2).contiguous() # [batch, seq_len, num_heads, head_dim]
attn_output = attn_output.view(batch_size, seq_len, -1) # [batch, seq_len, hidden_size]
output = self.o_proj(attn_output)
# Optional QKV output for debugging
qkv_dict = None
if output_qkv:
qkv_dict = {
"q": q, # [batch, num_heads, seq_len, head_dim] (post-RoPE)
"k": k, # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
"v": v, # [batch, num_heads, kv_seq_len, head_dim] (expanded)
}
return output, new_past_key_value, qkv_dict
class Qwen3MLP(nn.Module):
"""
Qwen3 MLP with SwiGLU activation.
Data Flow:
---------
hidden_states [batch, seq_len, hidden_size]
├──► gate_proj ──► gate [batch, seq_len, intermediate_size]
└──► up_proj ──► up [batch, seq_len, intermediate_size]
silu(gate) * up
down_proj ──► output [batch, seq_len, hidden_size]
"""
def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate = self.gate_proj(x)
up = self.up_proj(x)
return self.down_proj(F.silu(gate) * up)
class Qwen3DecoderLayer(nn.Module):
"""Single Qwen3 Decoder Layer."""
def __init__(
self,
hidden_size: int,
intermediate_size: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
layer_idx: int = 0,
):
super().__init__()
self.layer_idx = layer_idx
# Pre-attention LayerNorm
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# Self-attention
self.self_attn = Qwen3Attention(
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
attention_bias=attention_bias,
rms_norm_eps=rms_norm_eps,
layer_idx=layer_idx,
)
# Post-attention LayerNorm
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
# MLP
self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
output_qkv: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
"""
Args:
hidden_states: [batch, seq_len, hidden_size]
position_ids: [batch, seq_len]
attention_mask: Causal attention mask
past_key_value: KV cache for this layer
use_cache: Whether to return updated cache
output_qkv: Whether to output Q, K, V for debugging
Returns:
hidden_states: [batch, seq_len, hidden_size]
past_key_value: Updated cache
qkv_dict: QKV tensors (if output_qkv=True)
"""
# === Self Attention Block ===
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn_output, new_past_key_value, qkv_dict = self.self_attn(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_key_value,
use_cache=use_cache,
output_qkv=output_qkv,
)
hidden_states = residual + attn_output
# === MLP Block ===
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, new_past_key_value, qkv_dict
class Qwen3Model(nn.Module):
"""Qwen3 Transformer Model (without LM head)."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
):
super().__init__()
self.vocab_size = vocab_size
self.num_hidden_layers = num_hidden_layers
# Token embeddings
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
# Decoder layers
self.layers = nn.ModuleList([
Qwen3DecoderLayer(
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
layer_idx=i,
)
for i in range(num_hidden_layers)
])
# Final LayerNorm
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
position_ids: [batch, seq_len]
attention_mask: [batch, seq_len] or pre-computed 4D mask
past_key_values: List of (k, v) tuples for each layer
use_cache: Whether to return new cache
output_qkv_layers: List of layer indices to output QKV for
Returns:
hidden_states: [batch, seq_len, hidden_size]
new_past_key_values: Updated cache
qkv_outputs: {layer_idx: qkv_dict}
"""
batch_size, seq_len = input_ids.shape
# Embedding
hidden_states = self.embed_tokens(input_ids)
# Position IDs
if position_ids is None:
past_len = past_key_values[0][0].shape[2] if past_key_values else 0
position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
# Attention mask: build a 4D causal mask when none is given or only a 2D padding mask is passed (padding itself is not applied here)
if attention_mask is None or attention_mask.dim() == 2:
kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
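# The triu diagonal offset (past_len + 1) lets each of the seq_len new queries attend to all cached tokens and to itself, masking only strictly-future positions with -inf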
causal_mask = torch.triu(
torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
diagonal=kv_seq_len - seq_len + 1,
)
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, kv_seq_len]
# Initialize cache list
new_past_key_values = [] if use_cache else None
qkv_outputs = {} if output_qkv_layers else None
# Decoder layers
for i, layer in enumerate(self.layers):
past_kv = past_key_values[i] if past_key_values else None
output_qkv = output_qkv_layers is not None and i in output_qkv_layers
hidden_states, new_kv, qkv_dict = layer(
hidden_states=hidden_states,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_value=past_kv,
use_cache=use_cache,
output_qkv=output_qkv,
)
if use_cache:
new_past_key_values.append(new_kv)
if qkv_dict is not None:
qkv_outputs[i] = qkv_dict
# Final norm
hidden_states = self.norm(hidden_states)
return hidden_states, new_past_key_values, qkv_outputs
class Qwen3ForCausalLM(nn.Module):
"""Qwen3 Model with Language Modeling head."""
def __init__(
self,
vocab_size: int,
hidden_size: int,
intermediate_size: int,
num_hidden_layers: int,
num_attention_heads: int,
num_key_value_heads: int,
head_dim: int,
max_position_embeddings: int = 32768,
rope_theta: float = 10000.0,
rms_norm_eps: float = 1e-6,
attention_bias: bool = False,
mlp_bias: bool = False,
tie_word_embeddings: bool = True,
):
super().__init__()
self.vocab_size = vocab_size
self.tie_word_embeddings = tie_word_embeddings
# Transformer model
self.model = Qwen3Model(
vocab_size=vocab_size,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
num_key_value_heads=num_key_value_heads,
head_dim=head_dim,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
mlp_bias=mlp_bias,
)
# LM head
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
def forward(
self,
input_ids: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
use_cache: bool = False,
output_qkv_layers: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
"""
Args:
input_ids: [batch, seq_len]
... (same as Qwen3Model)
Returns:
logits: [batch, seq_len, vocab_size]
past_key_values: Updated KV cache
qkv_outputs: QKV tensors for specified layers
"""
hidden_states, new_past_key_values, qkv_outputs = self.model(
input_ids=input_ids,
position_ids=position_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_qkv_layers=output_qkv_layers,
)
logits = self.lm_head(hidden_states)
return logits, new_past_key_values, qkv_outputs
@classmethod
def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
"""
Load weights from a pretrained Qwen3 model.
Args:
model_path: Path to model directory containing config.json and model weights
dtype: Data type for model weights
Returns:
Initialized Qwen3ForCausalLM model
"""
import json
import os
from safetensors.torch import load_file
# Load config
config_path = os.path.join(model_path, "config.json")
with open(config_path) as f:
config = json.load(f)
# Create model
model = cls(
vocab_size=config["vocab_size"],
hidden_size=config["hidden_size"],
intermediate_size=config["intermediate_size"],
num_hidden_layers=config["num_hidden_layers"],
num_attention_heads=config["num_attention_heads"],
num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
max_position_embeddings=config.get("max_position_embeddings", 32768),
rope_theta=config.get("rope_theta", 10000.0),
rms_norm_eps=config.get("rms_norm_eps", 1e-6),
attention_bias=config.get("attention_bias", False),
mlp_bias=config.get("mlp_bias", False),
tie_word_embeddings=config.get("tie_word_embeddings", True),
)
# Load weights
weight_files = sorted([
f for f in os.listdir(model_path)
if f.endswith(".safetensors")
])
state_dict = {}
for wf in weight_files:
state_dict.update(load_file(os.path.join(model_path, wf)))
# Load into model
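# strict=False tolerates checkpoints that omit lm_head.weight when embeddings are tied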
model.load_state_dict(state_dict, strict=False)
# Tie lm_head weights to embed_tokens if configured
if model.tie_word_embeddings:
model.lm_head.weight = model.model.embed_tokens.weight
model = model.to(dtype)
return model
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 32,
temperature: float = 1.0,
do_sample: bool = True,
pad_token_id: Optional[int] = None,
eos_token_id: Optional[int] = None,
) -> torch.Tensor:
"""Simple autoregressive generation."""
device = input_ids.device
batch_size, seq_len = input_ids.shape
past_key_values = None
generated = input_ids.clone()
for _ in range(max_new_tokens):
if past_key_values is None:
current_input = generated
else:
current_input = generated[:, -1:]
logits, past_key_values, _ = self(
input_ids=current_input,
past_key_values=past_key_values,
use_cache=True,
)
next_token_logits = logits[:, -1, :]
if temperature > 0 and do_sample:
next_token_logits = next_token_logits / temperature
probs = torch.softmax(next_token_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
generated = torch.cat([generated, next_token], dim=1)
if eos_token_id is not None and (next_token == eos_token_id).all():
break
return generated
def print_computation_graph():
"""Print the computation graph for reference."""
print(__doc__)
if __name__ == "__main__":
print_computation_graph()
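A minimal usage sketch for the reference model above (the checkpoint path is an assumption; any local Qwen3 directory with config.json and safetensors files works):
```python
import os
import torch
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM

MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")  # illustrative local checkpoint

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=torch.float16).to("cuda").eval()

input_ids = tokenizer.encode("The capital of France is", return_tensors="pt").to("cuda")
# Greedy decoding via the simple generate() loop defined above
output_ids = model.generate(input_ids, max_new_tokens=8, do_sample=False)
print(tokenizer.decode(output_ids[0, input_ids.shape[1]:]))
```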

365
tests/test_align.py Normal file
View File

@@ -0,0 +1,365 @@
"""
Test alignment between nanovllm and custom torch Qwen3 implementation.
Compares attention layer outputs and QKV tensors to verify correctness.
Usage:
python test_align.py # Without CPU offload
python test_align.py --enable-offload # With CPU offload
python test_align.py --input-len 4096 # Custom input length
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import argparse
import torch
from transformers import AutoTokenizer
from nanovllm import LLM, SamplingParams
from modeling_qwen3 import Qwen3ForCausalLM
from utils import generate_needle_prompt
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("--enable-offload", action="store_true", help="Enable CPU offload")
parser.add_argument("--input-len", type=int, default=1024 * 12, help="Input sequence length")
parser.add_argument("--model-path", type=str, default="~/models/Qwen3-0.6B/", help="Model path")
parser.add_argument("--num-gpu-blocks", type=int, default=6, help="Number of GPU blocks (ring buffer slots)")
parser.add_argument("--block-size", type=int, default=1024, help="KV cache block size")
args = parser.parse_args()
# Config
MODEL_PATH = os.path.expanduser(args.model_path)
INPUT_LEN = args.input_len
ENABLE_OFFLOAD = args.enable_offload
NUM_GPU_BLOCKS = args.num_gpu_blocks
BLOCK_SIZE = args.block_size
DTYPE = torch.float16
print(f"Config: input_len={INPUT_LEN}, enable_offload={ENABLE_OFFLOAD}, num_gpu_blocks={NUM_GPU_BLOCKS}, block_size={BLOCK_SIZE}")
# Storage for captured tensors
nanovllm_outputs = {}
torch_outputs = {}
nanovllm_qkv = {}
nanovllm_proj_inputs = {}
torch_proj_inputs = {}
# ============================================================
# Hook functions for non-offload mode (overwrite)
# ============================================================
def make_nanovllm_hook(layer_id: int, storage: dict):
def hook(module, inputs, output):
attn_output = output[0] if isinstance(output, tuple) else output
if attn_output.dim() == 2:
attn_output = attn_output.unsqueeze(0)
storage[layer_id] = attn_output.detach().clone()
return hook
def make_nanovllm_qkv_hook(layer_id: int, storage: dict):
def hook(module, inputs):
q, k, v = inputs[0], inputs[1], inputs[2]
storage[layer_id] = {
"q": q.detach().clone(),
"k": k.detach().clone(),
"v": v.detach().clone(),
}
return hook
def make_proj_input_hook(layer_id: int, storage: dict):
def hook(module, inputs):
hidden = inputs[0]
if hidden.dim() == 2:
hidden = hidden.unsqueeze(0)
storage[layer_id] = hidden.detach().clone()
return hook
# ============================================================
# Hook functions for offload mode (accumulate Q and I, overwrite O)
# ============================================================
def make_accumulating_q_hook(layer_id: int, storage: dict):
"""Accumulate Q from each chunk for offload mode."""
def hook(module, inputs):
q = inputs[0].detach().clone()
if layer_id not in storage:
storage[layer_id] = []
storage[layer_id].append(q)
return hook
def make_accumulating_input_hook(layer_id: int, storage: dict):
"""Accumulate input hidden states from each chunk for offload mode."""
def hook(module, inputs):
hidden = inputs[0].detach().clone()
if layer_id not in storage:
storage[layer_id] = []
storage[layer_id].append(hidden)
return hook
def make_overwrite_output_hook(layer_id: int, storage: dict):
"""Overwrite output (keep only last chunk) for offload mode."""
def hook(module, inputs, output):
attn_output = output[0] if isinstance(output, tuple) else output
if attn_output.dim() == 2:
attn_output = attn_output.unsqueeze(0)
storage[layer_id] = attn_output.detach().clone()
return hook
# ============================================================
# CPU KV cache access for offload mode
# ============================================================
def get_nanovllm_kv_from_cpu(llm, seq, num_layers):
"""Get complete K, V cache from CPU side after all chunks finish."""
offload_engine = llm.model_runner.kvcache_manager.offload_engine
kvcache_manager = llm.model_runner.kvcache_manager
# CRITICAL: Synchronize all CUDA operations before reading CPU memory
# The D2H copy runs on transfer_stream_main and may still be in progress
torch.cuda.synchronize()
cpu_block_ids = kvcache_manager.get_cpu_block_table(seq)
kv_per_layer = {}
for layer_id in range(num_layers):
k_blocks = []
v_blocks = []
for cpu_block_id in cpu_block_ids:
k_block, v_block = offload_engine.get_cpu_block(layer_id, cpu_block_id)
k_blocks.append(k_block)
v_blocks.append(v_block)
# Concatenate all blocks: [total_tokens, kv_heads, head_dim]
k_full = torch.cat(k_blocks, dim=0)[:seq.num_tokens]
v_full = torch.cat(v_blocks, dim=0)[:seq.num_tokens]
kv_per_layer[layer_id] = {"k": k_full, "v": v_full}
return kv_per_layer
def make_torch_hook(layer_id: int, storage: dict):
def hook(module, inputs, output):
storage[layer_id] = output[0].detach().clone()
return hook
def cosine_sim(t1: torch.Tensor, t2: torch.Tensor) -> float:
"""Cosine similarity between flattened tensors (1.0 = identical)."""
return torch.nn.functional.cosine_similarity(
t1.flatten().float(), t2.flatten().float(), dim=0
).item()
def compute_qkv_sims(nano_qkv: dict, torch_qkv: dict, num_kv_groups: int):
"""Compute Q, K, V cosine similarities. Returns (q_sim, k_sim, v_sim)."""
nano_q = nano_qkv["q"]
torch_q = torch_qkv["q"].squeeze(0).transpose(0, 1)
nano_k = nano_qkv["k"]
torch_k = torch_qkv["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
nano_v = nano_qkv["v"]
torch_v = torch_qkv["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
return cosine_sim(nano_q, torch_q), cosine_sim(nano_k, torch_k), cosine_sim(nano_v, torch_v)
# ============================================================
# Load models
# ============================================================
print("Loading nanovllm model...")
llm_kwargs = dict(
enforce_eager=True,
max_model_len=32768,
gpu_memory_utilization=0.2,
max_num_batched_tokens=32768,
enable_cpu_offload=ENABLE_OFFLOAD,
dtype="float16",
kvcache_block_size=BLOCK_SIZE,
)
if ENABLE_OFFLOAD:
llm_kwargs["num_gpu_blocks"] = NUM_GPU_BLOCKS
llm = LLM(MODEL_PATH, **llm_kwargs)
num_heads = llm.model_runner.model.model.layers[0].self_attn.num_heads
num_kv_heads = llm.model_runner.model.model.layers[0].self_attn.num_kv_heads
num_kv_groups = num_heads // num_kv_heads
num_layers = len(llm.model_runner.model.model.layers)
print("Loading torch model...")
torch_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
torch_model = torch_model.to("cuda")
torch_model.eval()
# ============================================================
# Generate test input
# ============================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
prompt, _ = generate_needle_prompt(tokenizer=tokenizer, target_length=INPUT_LEN, verbose=True)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
print(f"Input shape: {input_ids.shape}")
# ============================================================
# Register hooks
# ============================================================
nanovllm_hooks = []
nanovllm_q_accum = {} # For offload mode: accumulated Q from all chunks
nanovllm_i_accum = {} # For offload mode: accumulated I from all chunks
for layer_idx, layer in enumerate(llm.model_runner.model.model.layers):
if ENABLE_OFFLOAD:
# Offload mode: accumulate Q and I, overwrite O
nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_overwrite_output_hook(layer_idx, nanovllm_outputs)))
nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_accumulating_q_hook(layer_idx, nanovllm_q_accum)))
nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_accumulating_input_hook(layer_idx, nanovllm_i_accum)))
else:
# Non-offload mode: overwrite all
nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_nanovllm_hook(layer_idx, nanovllm_outputs)))
nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_nanovllm_qkv_hook(layer_idx, nanovllm_qkv)))
nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, nanovllm_proj_inputs)))
torch_hooks = []
for layer_idx, layer in enumerate(torch_model.model.layers):
torch_hooks.append(layer.self_attn.register_forward_hook(make_torch_hook(layer_idx, torch_outputs)))
torch_hooks.append(layer.self_attn.q_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, torch_proj_inputs)))
# ============================================================
# Run inference
# ============================================================
print("Running nanovllm inference...")
if ENABLE_OFFLOAD:
# Manual execution to capture KV cache before deallocation
# Use max_tokens=2 so sequence doesn't finish immediately after prefill
llm.add_request(input_ids[0].tolist(), SamplingParams(temperature=0.01, max_tokens=2))
# Run prefill step (this calls run_chunked_offload_prefill internally)
output, num_tokens = llm.step()
print(f"[Offload] Prefill done: {num_tokens} tokens")
# Now seq is in running queue, KV cache is in CPU
seq = llm.scheduler.running[0]
print(f"[Offload] Sequence: {seq}")
# Get KV cache from CPU BEFORE decode step deallocates it
nanovllm_kv_cpu = get_nanovllm_kv_from_cpu(llm, seq, num_layers)
print(f"[Offload] Retrieved KV cache from CPU for {seq.num_tokens} tokens")
# IMPORTANT: Save outputs NOW before decode step overwrites them
# nanovllm_outputs contains prefill attention outputs at this point
nanovllm_outputs_prefill = {k: v.clone() for k, v in nanovllm_outputs.items()}
# Complete remaining steps (decode)
while not llm.is_finished():
llm.step()
# Use prefill outputs for comparison
nanovllm_outputs = nanovllm_outputs_prefill
else:
nanovllm_result = llm.generate([input_ids[0].tolist()], SamplingParams(temperature=0.01, max_tokens=1), use_tqdm=False)
print("Running torch inference...")
with torch.no_grad():
torch_logits, _, torch_qkv_outputs = torch_model(input_ids, output_qkv_layers=list(range(num_layers)))
# ============================================================
# Compare using cosine similarity (1.0 = perfect alignment)
# ============================================================
print("\n" + "=" * 70)
print(f"{'Layer':<8} {'I':>10} {'Q':>10} {'K':>10} {'V':>10} {'O':>10}")
print("=" * 70)
all_passed = True
threshold = 0.999 # Cosine similarity threshold
for layer_idx in range(num_layers):
if ENABLE_OFFLOAD:
# ============================================================
# Offload mode: use accumulated Q/I and CPU-side K/V
# Only compare prompt tokens (INPUT_LEN), exclude generated tokens
# ============================================================
# I: concatenate accumulated chunks, trim to prompt length
i_chunks = nanovllm_i_accum[layer_idx]
nano_in = torch.cat(i_chunks, dim=0)[:INPUT_LEN]
if nano_in.dim() == 2:
nano_in = nano_in.unsqueeze(0)
torch_in = torch_proj_inputs[layer_idx]
if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
torch_in = torch_in.view(nano_in.shape)
i_sim = cosine_sim(nano_in, torch_in)
# Q: concatenate accumulated chunks, trim to prompt length
q_chunks = nanovllm_q_accum[layer_idx]
nano_q = torch.cat(q_chunks, dim=0)[:INPUT_LEN]
torch_q = torch_qkv_outputs[layer_idx]["q"].squeeze(0).transpose(0, 1)
q_sim = cosine_sim(nano_q, torch_q)
# K, V: from CPU cache, trim to prompt length and move to GPU
nano_k = nanovllm_kv_cpu[layer_idx]["k"][:INPUT_LEN].cuda()
torch_k = torch_qkv_outputs[layer_idx]["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
k_sim = cosine_sim(nano_k, torch_k)
nano_v = nanovllm_kv_cpu[layer_idx]["v"][:INPUT_LEN].cuda()
torch_v = torch_qkv_outputs[layer_idx]["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
v_sim = cosine_sim(nano_v, torch_v)
# O: compare attention outputs directly
# For single-chunk case (input_len <= block_size), shapes should match
# For multi-chunk case, nano_out is the last chunk only
nano_out = nanovllm_outputs[layer_idx]
torch_out = torch_outputs[layer_idx]
if nano_out.numel() == torch_out.numel():
# Single chunk or shapes match - compare directly
o_sim = cosine_sim(nano_out, torch_out)
else:
# Multi-chunk case: compare last chunk with corresponding torch slice
last_chunk_len = nano_out.shape[1] if nano_out.dim() == 3 else nano_out.shape[0]
torch_out_slice = torch_out[:, -last_chunk_len:, :] if torch_out.dim() == 3 else torch_out[-last_chunk_len:, :]
o_sim = cosine_sim(nano_out, torch_out_slice)
else:
# ============================================================
# Non-offload mode: original logic
# ============================================================
# Input similarity
nano_in = nanovllm_proj_inputs[layer_idx]
torch_in = torch_proj_inputs[layer_idx]
if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
torch_in = torch_in.view(nano_in.shape)
i_sim = cosine_sim(nano_in, torch_in)
# QKV similarities
q_sim, k_sim, v_sim = compute_qkv_sims(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups)
# O similarity
nano_out = nanovllm_outputs[layer_idx]
torch_out = torch_outputs[layer_idx]
if nano_out.shape != torch_out.shape and nano_out.numel() == torch_out.numel():
torch_out = torch_out.view(nano_out.shape)
o_sim = cosine_sim(nano_out, torch_out)
# Check pass/fail
passed = all(s >= threshold for s in [i_sim, q_sim, k_sim, v_sim, o_sim])
all_passed = all_passed and passed
status = "" if passed else " *"
print(f"Layer {layer_idx:2d}{status:<3} {i_sim:>10.6f} {q_sim:>10.6f} {k_sim:>10.6f} {v_sim:>10.6f} {o_sim:>10.6f}")
# ============================================================
# Cleanup and result
# ============================================================
for hook in nanovllm_hooks + torch_hooks:
hook.remove()
print("=" * 70)
mode_str = " [offload]" if ENABLE_OFFLOAD else ""
if all_passed:
print(f"test_align{mode_str}: PASSED (cosine_sim >= 0.999)")
else:
print(f"test_align{mode_str}: FAILED (* = cosine_sim < 0.999)")

134
tests/test_nanovllm_steppable.py Normal file
View File

@@ -0,0 +1,134 @@
"""
Test NanovllmSteppable: Print activation statistics at each layer.
Usage:
python tests/test_nanovllm_steppable.py
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
import torch
from transformers import AutoTokenizer
from nanovllm import LLM
from nanovllm.debug.adapters.nanovllm_adapter import NanovllmSteppable
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
# Config
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
INPUT_LEN = 32768 # Longer context to test offload
MAX_NEW_TOKENS = 20
DTYPE = torch.float16
ENABLE_CPU_OFFLOAD = True # Test offload mode
# ============================================================
# Load Model
# ============================================================
print(f"Loading nanovllm model (cpu_offload={ENABLE_CPU_OFFLOAD})...")
llm = LLM(
MODEL_PATH,
enforce_eager=True, # Required for hooks to work
max_model_len=40960,
max_num_batched_tokens=40960,
enable_cpu_offload=ENABLE_CPU_OFFLOAD,
dtype="float16",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Get the underlying model for steppable
model = llm.model_runner.model
# ============================================================
# Prepare Input (using needle-in-haystack prompt)
# ============================================================
prompt, expected_answer = generate_needle_prompt(
tokenizer,
target_length=INPUT_LEN,
needle_position=0.5,
needle_value="7492",
use_chat_template=False,
verbose=True,
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
print(f"Input shape: {input_ids.shape}")
print(f"Expected answer: {expected_answer}\n")
# ============================================================
# Create Steppable Model (reused for prefill + decode)
# ============================================================
steppable = NanovllmSteppable(model)
# ============================================================
# Prefill Phase: Print activation stats
# ============================================================
print("=" * 85)
print("PREFILL PHASE")
print("=" * 85)
print(f"{'Layer':<15} {'Shape':<25} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
print("-" * 85)
current_ids = input_ids.clone()
logits = None
for bp in steppable.step(current_ids, is_prefill=True):
t = bp.tensor.float()
shape_str = str(list(t.shape))
print(f"{bp.name:<15} {shape_str:<25} {t.mean():>10.4f} {t.std():>10.4f} {t.min():>10.4f} {t.max():>10.4f}")
if bp.name == "LM Head":
logits = bp.tensor
# Get first token from prefill
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
generated_tokens = [next_token]
# ============================================================
# Decode Phase: Only print generated tokens
# ============================================================
print("\n" + "=" * 85)
print("DECODE PHASE")
print("=" * 85)
print(f"Step 1: {next_token!r}")
for step in range(2, MAX_NEW_TOKENS + 1):
# Forward pass with the full sequence (reuse the same steppable)
# Note: this debug path does not reuse the KV cache, so every decode step re-runs the full sequence as a prefill
for bp in steppable.step(current_ids, is_prefill=True):
if bp.name == "LM Head":
logits = bp.tensor
# Get next token (greedy)
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
generated_tokens.append(next_token)
print(f"Step {step:2}: {next_token!r}")
# Stop if EOS
if next_token_id == tokenizer.eos_token_id:
print(" (EOS)")
break
# Append to sequence
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
# ============================================================
# Result
# ============================================================
print("\n" + "=" * 85)
print("RESULT")
print("=" * 85)
generated_text = "".join(generated_tokens)
print(f"Generated: {generated_text!r}")
print(f"Expected: {expected_answer}")
print(f"Answer: {'CORRECT!' if check_needle_answer(generated_text, expected_answer) else 'INCORRECT'}")
print("\ntest_nanovllm_steppable: PASSED")

View File

@@ -8,155 +8,11 @@ sequences longer than ~200 tokens. Use --no-offload for correctness testing.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams
# ============================================================
# Needle Test Generator
# ============================================================
def generate_needle_prompt(
tokenizer,
target_length: int,
needle_position: float = 0.5,
needle_value: str = "7492",
use_chat_template: bool = True,
) -> tuple[str, str]:
"""
Generate a needle-in-haystack prompt of approximately target_length tokens.
Args:
tokenizer: HuggingFace tokenizer for length estimation
target_length: Target total sequence length in tokens
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
needle_value: The secret value to hide in the haystack
use_chat_template: Whether to use chat template for instruct models
Returns:
(prompt, expected_answer): The full prompt and the expected needle value
"""
# Haystack filler paragraphs (various topics to create realistic context)
haystack_paragraphs = [
"The weather today is quite pleasant with clear skies and moderate temperatures. "
"Many people are enjoying outdoor activities in the park. "
"Birds are singing in the trees and children are playing on the swings. ",
"In the world of technology, new innovations continue to emerge every day. "
"Researchers are working on advanced algorithms and computing systems. "
"The future of artificial intelligence looks promising with many breakthroughs. ",
"The history of human civilization spans thousands of years. "
"Ancient cultures developed writing, mathematics, and astronomy. "
"Trade routes connected distant lands and facilitated cultural exchange. ",
"Modern cooking combines traditional techniques with new ingredients. "
"Chefs around the world experiment with flavors and presentations. "
"Food brings people together and creates memorable experiences. ",
"The ocean covers more than seventy percent of Earth's surface. "
"Marine ecosystems support an incredible diversity of life forms. "
"Scientists continue to discover new species in the deep sea. ",
"Music has been a part of human culture since prehistoric times. "
"Different genres evolved across various regions and time periods. "
"Today, people can access millions of songs through digital platforms. ",
"Space exploration has revealed many secrets about our universe. "
"Telescopes can observe galaxies billions of light years away. "
"Future missions aim to establish human presence on other planets. ",
"The study of languages reveals patterns in human cognition. "
"Linguists analyze grammar, semantics, and phonetics across cultures. "
"Language continues to evolve with new words and expressions. ",
]
# The needle sentence
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
# Question at the end
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
# Estimate tokens for fixed parts
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
# Buffer for chat template, special tokens, etc.
overhead_tokens = 100 if use_chat_template else 50
# Available tokens for haystack
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
if haystack_target_tokens < 100:
raise ValueError(f"target_length {target_length} is too short for needle test")
# Build haystack by repeating paragraphs
haystack_parts = []
current_tokens = 0
para_idx = 0
while current_tokens < haystack_target_tokens:
para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
if current_tokens + para_tokens > haystack_target_tokens:
break
haystack_parts.append(para)
current_tokens += para_tokens
para_idx += 1
# Calculate needle insertion point
needle_idx = int(len(haystack_parts) * needle_position)
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
# Insert needle
haystack_parts.insert(needle_idx, needle)
# Assemble prompt
full_text = "".join(haystack_parts)
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
# Use chat template for instruct models
# For Qwen3, add /no_think to disable thinking mode
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
messages = [
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
else:
# Raw text format for base models
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
prompt = full_text + question
# Verify length
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
return prompt, needle_value
def check_needle_answer(output_text: str, expected: str) -> bool:
"""Check if the model output contains the expected needle value."""
import re
# Clean output - remove special tokens and whitespace
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
output_clean = ' '.join(output_clean.split()).lower()
expected_clean = expected.strip().lower()
# Check if expected value appears in output
# Also try to find it as a standalone number
if expected_clean in output_clean:
return True
# Try to extract numbers and check if expected is among them
numbers = re.findall(r'\d+', output_clean)
return expected_clean in numbers
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
@@ -168,6 +24,7 @@ def run_needle_test(
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
block_size: int = 1024,
needle_position: float = 0.5,
needle_value: str = "7492",
max_new_tokens: int = 32,
@@ -182,6 +39,7 @@ def run_needle_test(
max_model_len: Maximum model context length
input_len: Target input sequence length
num_gpu_blocks: Number of GPU blocks for offload
block_size: KV cache block size
needle_position: Where to place needle (0.0-1.0)
needle_value: The secret value to find
max_new_tokens: Maximum tokens to generate
@@ -198,6 +56,7 @@ def run_needle_test(
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Block size: {block_size}")
print(f"Needle position: {needle_position:.0%}")
print(f"Needle value: {needle_value}")
print(f"CPU offload: {enable_cpu_offload}")
@@ -209,6 +68,7 @@ def run_needle_test(
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
"kvcache_block_size": block_size,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
@@ -263,7 +123,7 @@ if __name__ == "__main__":
parser.add_argument(
"--max-model-len",
type=int,
default=32 * 1024,
default=36 * 1024,
help="Maximum model context length"
)
parser.add_argument(
@@ -278,6 +138,12 @@ if __name__ == "__main__":
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--block-size",
type=int,
default=1024,
help="KV cache block size"
)
parser.add_argument(
"--needle-position",
type=float,
@@ -308,6 +174,7 @@ if __name__ == "__main__":
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
block_size=args.block_size,
needle_position=args.needle_position,
needle_value=args.needle_value,
max_new_tokens=args.max_new_tokens,

View File

@@ -8,148 +8,9 @@ Uses standard HuggingFace inference (no custom KV cache, no offload).
import os
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# ============================================================
# Needle Test Generator
# ============================================================
def generate_needle_prompt(
tokenizer,
target_length: int,
needle_position: float = 0.5,
needle_value: str = "7492",
use_chat_template: bool = True,
) -> tuple[str, str]:
"""
Generate a needle-in-haystack prompt of approximately target_length tokens.
Args:
tokenizer: HuggingFace tokenizer for length estimation
target_length: Target total sequence length in tokens
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
needle_value: The secret value to hide in the haystack
use_chat_template: Whether to use chat template for instruct models
Returns:
(prompt, expected_answer): The full prompt and the expected needle value
"""
# Haystack filler paragraphs (various topics to create realistic context)
haystack_paragraphs = [
"The weather today is quite pleasant with clear skies and moderate temperatures. "
"Many people are enjoying outdoor activities in the park. "
"Birds are singing in the trees and children are playing on the swings. ",
"In the world of technology, new innovations continue to emerge every day. "
"Researchers are working on advanced algorithms and computing systems. "
"The future of artificial intelligence looks promising with many breakthroughs. ",
"The history of human civilization spans thousands of years. "
"Ancient cultures developed writing, mathematics, and astronomy. "
"Trade routes connected distant lands and facilitated cultural exchange. ",
"Modern cooking combines traditional techniques with new ingredients. "
"Chefs around the world experiment with flavors and presentations. "
"Food brings people together and creates memorable experiences. ",
"The ocean covers more than seventy percent of Earth's surface. "
"Marine ecosystems support an incredible diversity of life forms. "
"Scientists continue to discover new species in the deep sea. ",
"Music has been a part of human culture since prehistoric times. "
"Different genres evolved across various regions and time periods. "
"Today, people can access millions of songs through digital platforms. ",
"Space exploration has revealed many secrets about our universe. "
"Telescopes can observe galaxies billions of light years away. "
"Future missions aim to establish human presence on other planets. ",
"The study of languages reveals patterns in human cognition. "
"Linguists analyze grammar, semantics, and phonetics across cultures. "
"Language continues to evolve with new words and expressions. ",
]
# The needle sentence
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
# Estimate tokens for fixed parts
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
# Buffer for chat template, special tokens, etc.
overhead_tokens = 100 if use_chat_template else 50
# Available tokens for haystack
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
if haystack_target_tokens < 100:
raise ValueError(f"target_length {target_length} is too short for needle test")
# Build haystack by repeating paragraphs
haystack_parts = []
current_tokens = 0
para_idx = 0
while current_tokens < haystack_target_tokens:
para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
if current_tokens + para_tokens > haystack_target_tokens:
break
haystack_parts.append(para)
current_tokens += para_tokens
para_idx += 1
# Calculate needle insertion point
needle_idx = int(len(haystack_parts) * needle_position)
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
# Insert needle
haystack_parts.insert(needle_idx, needle)
# Assemble prompt
full_text = "".join(haystack_parts)
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
# Use chat template for instruct models
# For Qwen3, add /no_think to disable thinking mode
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
messages = [
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
else:
# Raw text format for base models
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
prompt = full_text + question
# Verify length
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
return prompt, needle_value
def check_needle_answer(output_text: str, expected: str) -> bool:
"""Check if the model output contains the expected needle value."""
import re
# Clean output - remove special tokens and whitespace
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
output_clean = ' '.join(output_clean.split()).lower()
expected_clean = expected.strip().lower()
# Check if expected value appears in output
if expected_clean in output_clean:
return True
# Try to extract numbers and check if expected is among them
numbers = re.findall(r'\d+', output_clean)
return expected_clean in numbers
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
@@ -207,22 +68,19 @@ def run_needle_test(
# 3. Load model
print("[3/4] Loading model...")
torch_dtype = {
"auto": "auto",
"auto": torch.float16, # default to float16 for custom model
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}.get(dtype, "auto")
}.get(dtype, torch.float16)
model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
device_map="auto",
trust_remote_code=True,
)
model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch_dtype)
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()
# 4. Generate output
print("[4/4] Running inference...")
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
device = next(model.parameters()).device
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
print(f" Input shape: {input_ids.shape}")
with torch.no_grad():

121
tests/test_torch_steppable.py Normal file
View File

@@ -0,0 +1,121 @@
"""
Test TorchSteppable: Print activation statistics at each layer.
Usage:
python tests/test_torch_steppable.py
"""
import os
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
import torch
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from nanovllm.debug.adapters.torch_adapter import TorchSteppable
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
# Config
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
INPUT_LEN = 512
MAX_NEW_TOKENS = 20
DTYPE = torch.float16
# ============================================================
# Load Model
# ============================================================
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
model = model.to("cuda").eval()
# ============================================================
# Prepare Input (using needle-in-haystack prompt)
# ============================================================
prompt, expected_answer = generate_needle_prompt(
tokenizer,
target_length=INPUT_LEN,
needle_position=0.5,
needle_value="7492",
use_chat_template=False,
verbose=True,
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
print(f"Input shape: {input_ids.shape}")
print(f"Expected answer: {expected_answer}\n")
# ============================================================
# Create Steppable Model (reused for prefill + decode)
# ============================================================
steppable = TorchSteppable(model)
# ============================================================
# Prefill Phase: Print activation stats
# ============================================================
print("=" * 85)
print("PREFILL PHASE")
print("=" * 85)
print(f"{'Layer':<15} {'Shape':<25} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
print("-" * 85)
current_ids = input_ids.clone()
logits = None
for bp in steppable.step(current_ids):
t = bp.tensor.float()
shape_str = str(list(t.shape))
print(f"{bp.name:<15} {shape_str:<25} {t.mean():>10.4f} {t.std():>10.4f} {t.min():>10.4f} {t.max():>10.4f}")
if bp.name == "LM Head":
logits = bp.tensor
# Get first token from prefill
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
generated_tokens = [next_token]
# ============================================================
# Decode Phase: Only print generated tokens
# ============================================================
print("\n" + "=" * 85)
print("DECODE PHASE")
print("=" * 85)
print(f"Step 1: {next_token!r}")
for step in range(2, MAX_NEW_TOKENS + 1):
# Forward pass (reuse same steppable)
for bp in steppable.step(current_ids):
if bp.name == "LM Head":
logits = bp.tensor
# Get next token (greedy)
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
generated_tokens.append(next_token)
print(f"Step {step:2}: {next_token!r}")
# Stop if EOS
if next_token_id == tokenizer.eos_token_id:
print(" (EOS)")
break
# Append to sequence
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
# ============================================================
# Result
# ============================================================
print("\n" + "=" * 85)
print("RESULT")
print("=" * 85)
generated_text = "".join(generated_tokens)
print(f"Generated: {generated_text!r}")
print(f"Expected: {expected_answer}")
print(f"Answer: {'CORRECT!' if check_needle_answer(generated_text, expected_answer) else 'INCORRECT'}")
print("\ntest_torch_steppable: PASSED")

186
tests/utils.py Normal file
View File

@@ -0,0 +1,186 @@
"""
Test utilities for nano-vllm.
"""
import re
from typing import Tuple
# ============================================================
# Needle-in-Haystack Test Utilities
# ============================================================
# Haystack filler paragraphs (various topics to create realistic context)
HAYSTACK_PARAGRAPHS = [
"The weather today is quite pleasant with clear skies and moderate temperatures. "
"Many people are enjoying outdoor activities in the park. "
"Birds are singing in the trees and children are playing on the swings. ",
"In the world of technology, new innovations continue to emerge every day. "
"Researchers are working on advanced algorithms and computing systems. "
"The future of artificial intelligence looks promising with many breakthroughs. ",
"The history of human civilization spans thousands of years. "
"Ancient cultures developed writing, mathematics, and astronomy. "
"Trade routes connected distant lands and facilitated cultural exchange. ",
"Modern cooking combines traditional techniques with new ingredients. "
"Chefs around the world experiment with flavors and presentations. "
"Food brings people together and creates memorable experiences. ",
"The ocean covers more than seventy percent of Earth's surface. "
"Marine ecosystems support an incredible diversity of life forms. "
"Scientists continue to discover new species in the deep sea. ",
"Music has been a part of human culture since prehistoric times. "
"Different genres evolved across various regions and time periods. "
"Today, people can access millions of songs through digital platforms. ",
"Space exploration has revealed many secrets about our universe. "
"Telescopes can observe galaxies billions of light years away. "
"Future missions aim to establish human presence on other planets. ",
"The study of languages reveals patterns in human cognition. "
"Linguists analyze grammar, semantics, and phonetics across cultures. "
"Language continues to evolve with new words and expressions. ",
]
def generate_needle_prompt(
tokenizer,
target_length: int,
needle_position: float = 0.5,
needle_value: str = "7492",
use_chat_template: bool = True,
verbose: bool = True,
) -> Tuple[str, str]:
"""
Generate a needle-in-haystack prompt as close as possible to target_length tokens.
Args:
tokenizer: HuggingFace tokenizer for length estimation
target_length: Target total sequence length in tokens
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
needle_value: The secret value to hide in the haystack
use_chat_template: Whether to use chat template for instruct models
verbose: Whether to print generation info
Returns:
(prompt, expected_answer): The full prompt and the expected needle value
"""
# The needle sentence
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
# Question text
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
else:
question_text = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
def build_prompt(haystack_parts, needle_idx):
"""Build full prompt from haystack parts with needle inserted."""
parts = haystack_parts.copy()
parts.insert(needle_idx, needle)
full_text = "".join(parts)
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
messages = [{"role": "user", "content": f"{full_text}\n\n{question_text}"}]
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
else:
return full_text + question_text
def count_tokens(prompt):
return len(tokenizer.encode(prompt, add_special_tokens=False))
def get_needle_idx(parts):
idx = int(len(parts) * needle_position)
return max(0, min(idx, len(parts)))
# Phase 1: Build haystack with full paragraphs until we exceed target
haystack_parts = []
para_idx = 0
while True:
para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
test_parts = haystack_parts + [para]
prompt = build_prompt(test_parts, get_needle_idx(test_parts))
if count_tokens(prompt) > target_length:
break
haystack_parts.append(para)
para_idx += 1
if para_idx > 10000: # Safety limit
break
# Phase 2: Fine-tune by adding words from next paragraph
next_para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
words = next_para.split()
best_parts = haystack_parts.copy()
best_diff = abs(target_length - count_tokens(build_prompt(haystack_parts, get_needle_idx(haystack_parts))))
for i in range(1, len(words) + 1):
partial = " ".join(words[:i]) + " "
test_parts = haystack_parts + [partial]
prompt = build_prompt(test_parts, get_needle_idx(test_parts))
token_count = count_tokens(prompt)
diff = abs(target_length - token_count)
if diff < best_diff:
best_diff = diff
best_parts = test_parts.copy()
if token_count >= target_length:
break
haystack_parts = best_parts
# Final build
needle_idx = get_needle_idx(haystack_parts)
prompt = build_prompt(haystack_parts, needle_idx)
actual_tokens = count_tokens(prompt)
if verbose:
print(f"[NeedleTest] Target: {target_length}, Actual: {actual_tokens} tokens (diff={actual_tokens - target_length})")
return prompt, needle_value
def check_needle_answer(output_text: str, expected: str) -> bool:
"""Check if the model output contains the expected needle value."""
# Clean output - remove special tokens and whitespace
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
output_clean = ' '.join(output_clean.split()).lower()
expected_clean = expected.strip().lower()
# Check if expected value appears in output
# Also try to find it as a standalone number
if expected_clean in output_clean:
return True
# Try to extract numbers and check if expected is among them
numbers = re.findall(r'\d+', output_clean)
return expected_clean in numbers
def generate_random_token_ids(
length: int,
vocab_size: int = 10000,
seed: int = 42,
) -> list:
"""
Generate random token IDs for testing.
Args:
length: Number of tokens to generate
vocab_size: Maximum token ID (exclusive)
seed: Random seed for reproducibility
Returns:
List of random token IDs
"""
from random import randint, seed as set_seed
set_seed(seed)
return [randint(0, vocab_size - 1) for _ in range(length)]
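A short sketch of how these helpers combine in a test (the model path is an assumption and the output string stands in for real model output):
```python
import os
from transformers import AutoTokenizer
from utils import generate_needle_prompt, check_needle_answer

MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")  # illustrative local checkpoint
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

prompt, expected = generate_needle_prompt(
    tokenizer,
    target_length=4096,
    needle_position=0.25,     # hide the needle in the first quarter of the haystack
    needle_value="7492",
    use_chat_template=False,  # raw-text format for base models
)
# ... run any model on `prompt` and collect its text output ...
output_text = "The secret number is 7492."  # placeholder model output
assert check_needle_answer(output_text, expected)
```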