Compare commits
11 Commits
ff8b09cd35 ... tzj/minfer

| SHA1 |
|---|
| 247c5312d9 |
| 054aaff403 |
| d623043a3c |
| e897380127 |
| 24096431ed |
| 772313db8f |
| 00ed17c640 |
| 9b52d25866 |
| 8c3418725b |
| b3685c9190 |
| 6927a75ac3 |
76  CLAUDE.md
@@ -20,6 +20,80 @@ For sparse attention related content (block sparse attention, MInference, FlexPr
- **BlockManager** (`block_manager.py`): Paged attention with prefix caching (xxhash), default block size 4096
- **Attention** (`layers/attention.py`): FlashAttention with chunked methods for CPU offload

## PyTorch Hooks for Debugging

### Hook Positions in Qwen3

```
decoder_layer
├── input_layernorm (RMSNorm)
├── self_attn (Qwen3Attention)        ← Hook here for attention I/O after o_proj
│   ├── q_proj → q_norm → RoPE
│   ├── k_proj → k_norm → RoPE
│   ├── v_proj
│   ├── attn (Attention)              ← Hook here for Q/K/V tensors
│   │   └── FlashAttention / SDPA
│   └── o_proj
├── post_attention_layernorm (RMSNorm)
└── mlp (Qwen3MLP)
```

### Hook Types & Data Shapes

| Hook Position | Type | Captured Data |
|---------------|------|---------------|
| `self_attn` | post | `[batch, seq_len, hidden_size]` - after o_proj |
| `self_attn.attn` | pre | Q, K, V: `[seq_len, num_heads, head_dim]` - after RoPE |
| `self_attn.attn` | post | `[seq_len, num_heads, head_dim]` - before o_proj |

### Example: Capture Attention Outputs

```python
storage = {}

def make_hook(layer_id: int, storage: dict):
    def hook(module, inputs, output):
        if isinstance(output, tuple):
            attn_output = output[0]
        else:
            attn_output = output
        # nanovllm shape: [num_tokens, hidden_size] -> add batch dim
        if attn_output.dim() == 2:
            attn_output = attn_output.unsqueeze(0)
        storage[layer_id] = attn_output.detach().clone()
    return hook

# Register hooks
hooks = []
for layer_idx, layer in enumerate(model.model.layers):
    hooks.append(layer.self_attn.register_forward_hook(make_hook(layer_idx, storage)))

# Run inference...

# Cleanup
for hook in hooks:
    hook.remove()
```

### Alignment Testing

Use `tests/test_align.py` to compare nanovllm against the reference torch implementation:

```bash
python tests/test_align.py
```

Key files:
- `tests/modeling_qwen3.py`: Reference Qwen3 implementation (torch + transformers only)
- `tests/test_align.py`: Compares attention outputs between nanovllm and the reference
- `tests/test_needle_ref.py`: Reference needle test using the custom Qwen3

### Common Pitfalls

1. **Shape mismatch**: nanovllm uses `[num_tokens, ...]` while torch uses `[batch, seq_len, ...]`
2. **Hook position**: `self_attn` captures after o_proj; `self_attn.attn` captures before o_proj
3. **Output format**: nanovllm returns the tuple `(attn_output, None)`; unwrap it with `output[0]` (see the sketch below)
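Pitfalls 1 and 3 can be handled in one place. Below is a minimal, hypothetical helper (not part of the repo) that normalizes a hooked nanovllm output so it can be compared element-wise against the torch reference; it assumes the tuple layout and shapes described above.

```python
import torch

def normalize_hooked_output(output) -> torch.Tensor:
    """Unwrap a hooked nanovllm output and add a batch dimension.

    Handles both pitfalls: the (attn_output, None) tuple format and the
    missing batch dimension ([num_tokens, ...] vs [batch, seq_len, ...]).
    """
    tensor = output[0] if isinstance(output, tuple) else output
    if tensor.dim() == 2:             # [num_tokens, hidden_size]
        tensor = tensor.unsqueeze(0)  # -> [1, num_tokens, hidden_size]
    return tensor.detach().clone()

# Example: compare against a reference tensor of shape [1, seq_len, hidden_size]
# torch.testing.assert_close(normalize_hooked_output(out), ref, rtol=1e-2, atol=1e-2)
```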

## CPU Offload System

### Ring Buffer Design

@@ -228,7 +302,7 @@ def _merge_output_kernel(...):

| Parameter | Default | Notes |
|-----------|---------|-------|
-| `kvcache_block_size` | 4096 | Tokens per block |
+| `kvcache_block_size` | 1024 | Tokens per block |
| `max_num_batched_tokens` | 16384 | Set = max_model_len for long context |
| `gpu_memory_utilization` | 0.9 | GPU memory fraction |
| `enable_cpu_offload` | False | Enable for long context |

103  DEBUG_SUMMARY.md  Normal file
@@ -0,0 +1,103 @@
# Chunked Prefill Bug Debug Summary

## Problem
`test_needle.py --enable-offload --input-len 8192` fails with garbage output.

The model generates completely wrong tokens instead of the expected "7492".

## Investigation Progress

### 1. Stream Synchronization Fix (Completed)
- Replaced the Triton `store_kvcache` kernel with pure PyTorch operations
- Moved `store_kvcache` to `compute_stream` in chunked prefill mode
- Added sync: `compute_stream.wait_event(offload_done)` after the per-layer offload
- Added sync: `default_stream.wait_stream(compute_stream)` before return (see the sketch below)
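A minimal sketch of the synchronization pattern these fixes introduce, using generic tensors rather than the real engine objects (the stream and event names here are illustrative, not the actual attributes in `offload_engine.py`):

```python
import torch

compute_stream = torch.cuda.Stream()
offload_stream = torch.cuda.Stream()
offload_done = torch.cuda.Event()

k = torch.randn(4096, 8, 128, device="cuda")
k_cpu = torch.empty(k.shape, dtype=k.dtype, device="cpu", pin_memory=True)

with torch.cuda.stream(compute_stream):
    k = k * 2.0                      # stand-in for store_kvcache / attention work

# Offload the layer's KV on a separate stream, then record completion.
offload_stream.wait_stream(compute_stream)       # the copy must see the finished write
with torch.cuda.stream(offload_stream):
    k_cpu.copy_(k, non_blocking=True)
    offload_done.record(offload_stream)

# The next layer's compute must not overwrite the GPU slot before the copy finishes.
compute_stream.wait_event(offload_done)

# Results must be visible to the default stream before the rest of the model runs.
torch.cuda.default_stream().wait_stream(compute_stream)
```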

### 2. KV Cache Alignment Verification (Completed)
Created alignment tests to compare K/V tensors between the torch reference and nanovllm:

**RoPE Alignment:**
- The RoPE implementations match (max_diff=0.002, cosine ~1.0)
- Confirmed RoPE is NOT the cause of the bug

**K/V Cache Alignment (Chunk 0):**
- Cosine similarity: ~1.0 for all layers
- Max diff: 2-7 (grows linearly with position, characteristic of FP16 precision)
- Mean diff: < 0.001
- **Conclusion: K/V cache offload is working correctly** (the comparison metrics are sketched below)
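For reference, the metrics quoted throughout this summary (cosine similarity, max/mean absolute difference) can be computed as follows; this is a standalone sketch of the same metrics the debug comparator uses, not the comparator itself.

```python
import torch
import torch.nn.functional as F

def compare_tensors(ref: torch.Tensor, test: torch.Tensor) -> dict:
    """Cosine similarity and absolute-difference stats between two tensors."""
    ref_f, test_f = ref.float().flatten(), test.float().flatten()
    cos = F.cosine_similarity(ref_f.unsqueeze(0), test_f.unsqueeze(0)).item()
    diff = (ref.float() - test.float()).abs()
    return {"cosine": cos, "max_diff": diff.max().item(), "mean_diff": diff.mean().item()}

# e.g. compare_tensors(ref_k_chunk0, nanovllm_k_chunk0)
# -> {'cosine': ~1.0, 'max_diff': a few units in FP16, 'mean_diff': < 1e-3}
```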

### 3. Layer Output Divergence Analysis (Completed)
Created a per-chunk layer output comparison:

**Chunk 0 (tokens 0-4096):**
- All layers pass with excellent cosine similarity (0.999+)
- Max diff grows in later layers but stays within an acceptable range

**Chunk 1 (tokens 4096-8192):**
- Layers 0-19: OK (cosine ~1.0)
- Layers 20-27: diverge (cosine 0.83-0.96, max_diff up to 114)
- The divergence correlates with later transformer layers

### 4. Critical Discovery: Single-Chunk Offload Also Fails
**Key finding:** Even with input_len=2048 (a single chunk, no chunked attention), the model produces garbage output with CPU offload enabled.

```
# Without offload: PASSES
python tests/test_needle.py --input-len 2048
# Output: "7492" (correct)

# With offload: FAILS
python tests/test_needle.py --enable-offload --input-len 2048
# Output: "The Ble White Th G Lopsiswin..." (garbage)
```

**This proves the bug is NOT in:**
- Chunked attention logic (merge_attention_outputs)
- Multi-chunk KV loading
- Ring buffer pipeline

**The bug IS in:**
- The decode path when CPU offload is enabled
- How prefilled KV is loaded/used during decode

### 5. Decode Path Analysis (In Progress)
The decode path in CPU offload mode:
1. Prefill writes KV to GPU, then offloads it to CPU
2. Decode loads the prefilled KV from CPU via `_decode_ring_buffer_pipeline`
3. Attention runs over the prefilled KV plus the accumulated decode tokens
4. Results are merged

**Observations** (a diagnostic sketch follows this list):
- The `prefilled_blocks` set is empty after decode (it should contain block IDs)
- The CPU cache has valid data (reasonable mean/std values)
- The decode buffer contains zeros (decode tokens not being stored correctly?)
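A hedged sketch of the kind of post-run inspection behind these observations; the attribute names (`prefilled_blocks`, the CPU cache and decode-buffer tensors) are illustrative stand-ins for whatever the hybrid manager actually exposes, not confirmed APIs.

```python
import torch

def inspect_offload_state(kvcache_manager) -> None:
    """Print the tracking/state values that the observations above refer to."""
    # 1. Is prefill-block tracking populated? (expected: non-empty after prefill)
    print("prefilled_blocks:", getattr(kvcache_manager, "prefilled_blocks", "<missing>"))

    # 2. Does the CPU cache hold plausible data? (expected: non-zero mean/std)
    cpu_cache = getattr(kvcache_manager, "cpu_k_cache", None)     # assumed name
    if isinstance(cpu_cache, torch.Tensor):
        print("cpu k cache mean/std:",
              cpu_cache.float().mean().item(), cpu_cache.float().std().item())

    # 3. Is the decode buffer being written? (all zeros would match the bug)
    decode_buf = getattr(kvcache_manager, "decode_buffer", None)  # assumed name
    if isinstance(decode_buf, torch.Tensor):
        print("decode buffer all zero:", bool((decode_buf == 0).all().item()))
```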

## Current Status

### Working
- Stream synchronization fixes
- K/V cache offload to CPU (verified alignment)
- RoPE implementation
- Chunked prefill attention for the first chunk

### Not Working
- Decode with CPU offload (even for single-chunk inputs)
- Multi-chunk attention (divergence in later layers for chunk 1)

## Next Steps
1. Debug why `prefilled_blocks` is empty after decode
2. Check whether the decode path correctly loads KV from CPU
3. Verify that the decode buffer is being written correctly
4. Compare decode attention outputs between offload and non-offload modes

## Key Files
- `nanovllm/layers/attention.py` - Main attention implementation with the chunked paths
- `nanovllm/kvcache/offload_engine.py` - CPU-GPU transfer engine
- `nanovllm/kvcache/hybrid_manager.py` - KV cache management with `prefilled_blocks`
- `nanovllm/engine/model_runner.py` - Prefill/decode orchestration

## Hypothesis
The decode path fails because:
1. `prefilled_blocks` is not being tracked correctly, causing `get_prefilled_cpu_blocks()` to return an empty list,
2. OR the decode attention is not correctly loading/using the prefilled KV from CPU,
3. OR there is a stream synchronization issue specific to the decode path.

@@ -15,7 +15,7 @@ class Config:
     enforce_eager: bool = False
     hf_config: AutoConfig | None = None
     eos: int = -1
-    kvcache_block_size: int = 4096
+    kvcache_block_size: int = 1024
     num_kvcache_blocks: int = -1
     dtype: str | None = None  # "float16", "bfloat16", or None (use model default)
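One practical consequence of the smaller default block size: a given context now spans four times as many KV-cache blocks. A quick sanity-check calculation (pure arithmetic, no repo APIs):

```python
def num_blocks_needed(num_tokens: int, block_size: int) -> int:
    """Ceil-divide a token count into KV-cache blocks."""
    return (num_tokens + block_size - 1) // block_size

for block_size in (4096, 1024):
    print(block_size, num_blocks_needed(8192, block_size))
# 4096 -> 2 blocks, 1024 -> 8 blocks for the 8192-token needle test
```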

49  nanovllm/debug/__init__.py  Normal file
@@ -0,0 +1,49 @@
"""
|
||||
Breakpoint debugging tools for aligning nanovllm with reference implementations.
|
||||
|
||||
This module provides a generator-based breakpoint aligner that enables step-by-step
|
||||
comparison between nanovllm and torch reference model outputs.
|
||||
|
||||
Example:
|
||||
>>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
|
||||
>>> from tests.modeling_qwen3 import Qwen3ForCausalLM
|
||||
>>>
|
||||
>>> # Load models
|
||||
>>> torch_model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch.float16)
|
||||
>>> nanovllm_model = ... # Your nanovllm model
|
||||
>>>
|
||||
>>> # Create adapters
|
||||
>>> ref = TorchSteppable(torch_model)
|
||||
>>> test = NanovllmSteppable(nanovllm_model)
|
||||
>>>
|
||||
>>> # Run alignment
|
||||
>>> aligner = BreakpointAligner(ref, test)
|
||||
>>> result = aligner.align(input_ids)
|
||||
>>> print(result)
|
||||
"""
|
||||
|
||||
from .breakpoints import BreakpointType, Breakpoint
|
||||
from .comparator import TensorComparator, ComparisonResult
|
||||
from .aligner import BreakpointAligner, AlignmentResult
|
||||
from .adapters import SteppableModel, TorchSteppable, NanovllmSteppable
|
||||
from .utils import setup_prefill_context, setup_decode_context, cleanup_context
|
||||
|
||||
__all__ = [
|
||||
# Core classes
|
||||
"BreakpointAligner",
|
||||
"AlignmentResult",
|
||||
# Breakpoints
|
||||
"BreakpointType",
|
||||
"Breakpoint",
|
||||
# Comparator
|
||||
"TensorComparator",
|
||||
"ComparisonResult",
|
||||
# Adapters
|
||||
"SteppableModel",
|
||||
"TorchSteppable",
|
||||
"NanovllmSteppable",
|
||||
# Utils
|
||||
"setup_prefill_context",
|
||||
"setup_decode_context",
|
||||
"cleanup_context",
|
||||
]
|
||||

11  nanovllm/debug/adapters/__init__.py  Normal file
@@ -0,0 +1,11 @@
"""Model adapters for breakpoint alignment."""
|
||||
|
||||
from .base import SteppableModel
|
||||
from .torch_adapter import TorchSteppable
|
||||
from .nanovllm_adapter import NanovllmSteppable
|
||||
|
||||
__all__ = [
|
||||
"SteppableModel",
|
||||
"TorchSteppable",
|
||||
"NanovllmSteppable",
|
||||
]
|
||||
59
nanovllm/debug/adapters/base.py
Normal file
59
nanovllm/debug/adapters/base.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Base class for steppable model adapters."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generator, Set, Optional
|
||||
import torch
|
||||
|
||||
from ..breakpoints import Breakpoint, BreakpointType
|
||||
|
||||
|
||||
class SteppableModel(ABC):
|
||||
"""
|
||||
Abstract base class for models that can yield at breakpoints.
|
||||
|
||||
Subclasses implement the step() method as a generator that yields
|
||||
Breakpoint objects at each enabled breakpoint during forward pass.
|
||||
"""
|
||||
|
||||
def __init__(self, enabled_breakpoints: Optional[Set[BreakpointType]] = None):
|
||||
"""
|
||||
Args:
|
||||
enabled_breakpoints: Set of breakpoint types to yield at.
|
||||
If None, yields at all breakpoints.
|
||||
"""
|
||||
self.enabled_breakpoints = enabled_breakpoints
|
||||
|
||||
def is_enabled(self, bp_type: BreakpointType) -> bool:
|
||||
"""Check if a breakpoint type is enabled."""
|
||||
if self.enabled_breakpoints is None:
|
||||
return True
|
||||
return bp_type in self.enabled_breakpoints
|
||||
|
||||
@abstractmethod
|
||||
def step(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
) -> Generator[Breakpoint, None, torch.Tensor]:
|
||||
"""
|
||||
Generator that yields Breakpoint objects at enabled breakpoints.
|
||||
|
||||
Args:
|
||||
input_ids: Input token IDs
|
||||
positions: Position IDs (optional, auto-generated if None)
|
||||
is_prefill: True for prefill phase, False for decode
|
||||
|
||||
Yields:
|
||||
Breakpoint objects at each enabled checkpoint
|
||||
|
||||
Returns:
|
||||
Final output tensor (logits)
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def num_layers(self) -> int:
|
||||
"""Return the number of decoder layers."""
|
||||
pass
|
||||
235
nanovllm/debug/adapters/nanovllm_adapter.py
Normal file
235
nanovllm/debug/adapters/nanovllm_adapter.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""Nanovllm model adapter for breakpoint alignment."""
|
||||
|
||||
from typing import Generator, Set, Optional, Dict, Any, List
|
||||
import torch
|
||||
|
||||
from nanovllm.utils.context import set_context, reset_context
|
||||
from ..breakpoints import Breakpoint, BreakpointType
|
||||
from .base import SteppableModel
|
||||
|
||||
|
||||
class NanovllmSteppable(SteppableModel):
|
||||
"""
|
||||
Steppable adapter for nanovllm Qwen3 implementation.
|
||||
|
||||
Uses PyTorch hooks to capture intermediate values during forward pass,
|
||||
then yields them as breakpoints after execution completes.
|
||||
|
||||
Key challenges handled:
|
||||
1. Shape difference: nanovllm uses [num_tokens, hidden] vs [batch, seq, hidden]
|
||||
2. Context-based attention: must call set_context() before forward
|
||||
3. Fused operations: decoder layer returns (hidden_states, residual) tuple
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
enabled_breakpoints: Optional[Set[BreakpointType]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model: Qwen3ForCausalLM from nanovllm
|
||||
enabled_breakpoints: Set of breakpoint types to yield at
|
||||
"""
|
||||
super().__init__(enabled_breakpoints)
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
self._hooks: List[Any] = []
|
||||
self._captured: Dict[str, torch.Tensor] = {}
|
||||
|
||||
@property
|
||||
def num_layers(self) -> int:
|
||||
return len(self.model.model.layers)
|
||||
|
||||
def _register_hooks(self):
|
||||
"""Register forward hooks on all relevant modules."""
|
||||
self._hooks = []
|
||||
self._captured = {}
|
||||
|
||||
# Hook for embedding output
|
||||
def embed_hook(module, input, output):
|
||||
self._captured["embed"] = output.detach().clone()
|
||||
|
||||
self._hooks.append(
|
||||
self.model.model.embed_tokens.register_forward_hook(embed_hook)
|
||||
)
|
||||
|
||||
# Hooks for each decoder layer
|
||||
for layer_idx in range(self.num_layers):
|
||||
layer = self.model.model.layers[layer_idx]
|
||||
|
||||
def make_layer_hook(idx):
|
||||
def hook(module, input, output):
|
||||
# Decoder layer returns (hidden_states, residual)
|
||||
# hidden_states is MLP output, residual is accumulated residual
|
||||
# To match torch reference, we need hidden_states + residual
|
||||
if isinstance(output, tuple) and len(output) >= 2:
|
||||
hidden_states, residual = output[0], output[1]
|
||||
full_output = hidden_states + residual
|
||||
else:
|
||||
full_output = output
|
||||
self._captured[f"layer_{idx}"] = full_output.detach().clone()
|
||||
return hook
|
||||
|
||||
self._hooks.append(
|
||||
layer.register_forward_hook(make_layer_hook(layer_idx))
|
||||
)
|
||||
|
||||
# Hook for final norm
|
||||
def final_norm_hook(module, input, output):
|
||||
# Final norm returns (hidden_states, _) for fused add
|
||||
hidden_states = output[0] if isinstance(output, tuple) else output
|
||||
self._captured["final_norm"] = hidden_states.detach().clone()
|
||||
|
||||
self._hooks.append(
|
||||
self.model.model.norm.register_forward_hook(final_norm_hook)
|
||||
)
|
||||
|
||||
# Hook for lm_head
|
||||
def lm_head_hook(module, input, output):
|
||||
self._captured["lm_head"] = output.detach().clone()
|
||||
|
||||
self._hooks.append(
|
||||
self.model.lm_head.register_forward_hook(lm_head_hook)
|
||||
)
|
||||
|
||||
def _remove_hooks(self):
|
||||
"""Remove all registered hooks."""
|
||||
for hook in self._hooks:
|
||||
hook.remove()
|
||||
self._hooks = []
|
||||
|
||||
def _setup_context(self, seq_len: int, device: torch.device, is_prefill: bool):
|
||||
"""
|
||||
Set up nanovllm context for forward pass.
|
||||
|
||||
For alignment testing, we use simple context without real KV cache.
|
||||
"""
|
||||
if is_prefill:
|
||||
# Prefill: process all tokens at once
|
||||
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
|
||||
# Use -1 for slot_mapping to skip KV cache writes
|
||||
slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)
|
||||
|
||||
set_context(
|
||||
is_prefill=True,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=seq_len,
|
||||
max_seqlen_k=seq_len,
|
||||
slot_mapping=slot_mapping,
|
||||
is_chunked_prefill=False,
|
||||
)
|
||||
else:
|
||||
# Decode: single token generation
|
||||
# For decode, we need context_lens and block_tables
|
||||
# For alignment testing without real KV cache, we use minimal setup
|
||||
context_lens = torch.tensor([seq_len - 1], dtype=torch.int32, device=device)
|
||||
# Single token slot
|
||||
slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
|
||||
# Empty block tables (no KV cache)
|
||||
block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)
|
||||
|
||||
set_context(
|
||||
is_prefill=False,
|
||||
slot_mapping=slot_mapping,
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
)
|
||||
|
||||
def _normalize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Normalize nanovllm tensor shape to [batch, seq_len, ...].
|
||||
|
||||
nanovllm uses [num_tokens, ...] format without batch dimension.
|
||||
We add batch dimension for comparison with torch model.
|
||||
"""
|
||||
if tensor.dim() == 2: # [num_tokens, hidden_size]
|
||||
return tensor.unsqueeze(0)
|
||||
return tensor
|
||||
|
||||
def step(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
) -> Generator[Breakpoint, None, torch.Tensor]:
|
||||
"""
|
||||
Execute nanovllm forward pass with hooks to capture breakpoints.
|
||||
|
||||
Unlike the torch adapter which manually steps through each component,
|
||||
we run the full forward pass and collect captured values afterward.
|
||||
"""
|
||||
# Ensure 1D for nanovllm (it expects [num_tokens])
|
||||
if input_ids.dim() == 2:
|
||||
input_ids = input_ids.squeeze(0)
|
||||
|
||||
seq_len = input_ids.numel()
|
||||
device = input_ids.device
|
||||
|
||||
# Generate position IDs if not provided
|
||||
if positions is None:
|
||||
positions = torch.arange(seq_len, device=device)
|
||||
elif positions.dim() == 2:
|
||||
positions = positions.squeeze(0)
|
||||
|
||||
# Register hooks
|
||||
self._register_hooks()
|
||||
|
||||
try:
|
||||
# Setup context for attention
|
||||
self._setup_context(seq_len, device, is_prefill)
|
||||
|
||||
# Run forward pass (hooks capture everything)
|
||||
with torch.no_grad():
|
||||
hidden_states = self.model(input_ids, positions)
|
||||
logits = self.model.compute_logits(hidden_states)
|
||||
|
||||
reset_context()
|
||||
|
||||
# Yield breakpoints in order from captured data
|
||||
|
||||
# EMBEDDING
|
||||
if self.is_enabled(BreakpointType.EMBEDDING) and "embed" in self._captured:
|
||||
yield Breakpoint(
|
||||
bp_type=BreakpointType.EMBEDDING,
|
||||
layer_idx=None,
|
||||
tensor=self._normalize_tensor(self._captured["embed"]),
|
||||
name="Embedding",
|
||||
)
|
||||
|
||||
# LAYER_OUTPUT for each layer
|
||||
if self.is_enabled(BreakpointType.LAYER_OUTPUT):
|
||||
for layer_idx in range(self.num_layers):
|
||||
key = f"layer_{layer_idx}"
|
||||
if key in self._captured:
|
||||
yield Breakpoint(
|
||||
bp_type=BreakpointType.LAYER_OUTPUT,
|
||||
layer_idx=layer_idx,
|
||||
tensor=self._normalize_tensor(self._captured[key]),
|
||||
name=f"Layer {layer_idx}",
|
||||
)
|
||||
|
||||
# FINAL_NORM
|
||||
if self.is_enabled(BreakpointType.FINAL_NORM) and "final_norm" in self._captured:
|
||||
yield Breakpoint(
|
||||
bp_type=BreakpointType.FINAL_NORM,
|
||||
layer_idx=None,
|
||||
tensor=self._normalize_tensor(self._captured["final_norm"]),
|
||||
name="Final Norm",
|
||||
)
|
||||
|
||||
# LM_HEAD
|
||||
if self.is_enabled(BreakpointType.LM_HEAD) and "lm_head" in self._captured:
|
||||
yield Breakpoint(
|
||||
bp_type=BreakpointType.LM_HEAD,
|
||||
layer_idx=None,
|
||||
tensor=self._normalize_tensor(self._captured["lm_head"]),
|
||||
name="LM Head",
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
finally:
|
||||
self._remove_hooks()
|
||||
self._captured = {}
|
||||
119
nanovllm/debug/adapters/torch_adapter.py
Normal file
119
nanovllm/debug/adapters/torch_adapter.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Torch reference model adapter for breakpoint alignment."""
|
||||
|
||||
from typing import Generator, Set, Optional
|
||||
import torch
|
||||
|
||||
from ..breakpoints import Breakpoint, BreakpointType
|
||||
from .base import SteppableModel
|
||||
|
||||
|
||||
class TorchSteppable(SteppableModel):
|
||||
"""
|
||||
Steppable adapter for the torch reference Qwen3 implementation.
|
||||
|
||||
Wraps tests/modeling_qwen3.py Qwen3ForCausalLM model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
enabled_breakpoints: Optional[Set[BreakpointType]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
model: Qwen3ForCausalLM from tests/modeling_qwen3.py
|
||||
enabled_breakpoints: Set of breakpoint types to yield at
|
||||
"""
|
||||
super().__init__(enabled_breakpoints)
|
||||
self.model = model
|
||||
self.model.eval()
|
||||
|
||||
@property
|
||||
def num_layers(self) -> int:
|
||||
return len(self.model.model.layers)
|
||||
|
||||
def step(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
) -> Generator[Breakpoint, None, torch.Tensor]:
|
||||
"""
|
||||
Generator that manually steps through the torch model.
|
||||
|
||||
The torch model uses [batch, seq_len, hidden_size] shapes.
|
||||
"""
|
||||
# Ensure batch dimension
|
||||
if input_ids.dim() == 1:
|
||||
input_ids = input_ids.unsqueeze(0)
|
||||
|
||||
batch_size, seq_len = input_ids.shape
|
||||
device = input_ids.device
|
||||
|
||||
# Generate position IDs if not provided
|
||||
if positions is None:
|
||||
positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
|
||||
elif positions.dim() == 1:
|
||||
positions = positions.unsqueeze(0)
|
||||
|
||||
with torch.no_grad():
|
||||
# === EMBEDDING ===
|
||||
hidden_states = self.model.model.embed_tokens(input_ids)
|
||||
|
||||
if self.is_enabled(BreakpointType.EMBEDDING):
|
||||
yield Breakpoint(
|
||||
bp_type=BreakpointType.EMBEDDING,
|
||||
layer_idx=None,
|
||||
tensor=hidden_states.detach().clone(),
|
||||
name="Embedding",
|
||||
)
|
||||
|
||||
# Create causal attention mask
|
||||
causal_mask = torch.triu(
|
||||
torch.full((seq_len, seq_len), float("-inf"), device=device),
|
||||
diagonal=1,
|
||||
)
|
||||
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
# === DECODER LAYERS ===
|
||||
for layer_idx, layer in enumerate(self.model.model.layers):
|
||||
hidden_states, _, _ = layer(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=positions,
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=None,
|
||||
use_cache=False,
|
||||
output_qkv=False,
|
||||
)
|
||||
|
||||
if self.is_enabled(BreakpointType.LAYER_OUTPUT):
|
||||
yield Breakpoint(
|
||||
bp_type=BreakpointType.LAYER_OUTPUT,
|
||||
layer_idx=layer_idx,
|
||||
tensor=hidden_states.detach().clone(),
|
||||
name=f"Layer {layer_idx}",
|
||||
)
|
||||
|
||||
# === FINAL NORM ===
|
||||
hidden_states = self.model.model.norm(hidden_states)
|
||||
|
||||
if self.is_enabled(BreakpointType.FINAL_NORM):
|
||||
yield Breakpoint(
|
||||
bp_type=BreakpointType.FINAL_NORM,
|
||||
layer_idx=None,
|
||||
tensor=hidden_states.detach().clone(),
|
||||
name="Final Norm",
|
||||
)
|
||||
|
||||
# === LM HEAD ===
|
||||
logits = self.model.lm_head(hidden_states)
|
||||
|
||||
if self.is_enabled(BreakpointType.LM_HEAD):
|
||||
yield Breakpoint(
|
||||
bp_type=BreakpointType.LM_HEAD,
|
||||
layer_idx=None,
|
||||
tensor=logits.detach().clone(),
|
||||
name="LM Head",
|
||||
)
|
||||
|
||||
return logits
|
||||
211
nanovllm/debug/aligner.py
Normal file
211
nanovllm/debug/aligner.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Breakpoint aligner for comparing model outputs."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Tuple, Optional
|
||||
import torch
|
||||
|
||||
from .breakpoints import Breakpoint
|
||||
from .comparator import TensorComparator, ComparisonResult
|
||||
from .adapters.base import SteppableModel
|
||||
|
||||
|
||||
@dataclass
|
||||
class AlignmentResult:
|
||||
"""Result of an alignment test."""
|
||||
passed: bool
|
||||
all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = field(default_factory=list)
|
||||
failed_at: Optional[Breakpoint] = None
|
||||
message: str = ""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
passed_count = sum(1 for _, _, c in self.all_comparisons if c.passed)
|
||||
total = len(self.all_comparisons)
|
||||
status = "PASSED" if self.passed else "FAILED"
|
||||
return f"AlignmentResult({status}, {passed_count}/{total} breakpoints passed)"
|
||||
|
||||
|
||||
class BreakpointAligner:
|
||||
"""
|
||||
Orchestrates alternating execution of reference and test models,
|
||||
comparing outputs at each breakpoint.
|
||||
|
||||
Example:
|
||||
>>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
|
||||
>>> ref = TorchSteppable(torch_model)
|
||||
>>> test = NanovllmSteppable(nanovllm_model)
|
||||
>>> aligner = BreakpointAligner(ref, test)
|
||||
>>> result = aligner.align(input_ids)
|
||||
>>> print(result)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ref_model: SteppableModel,
|
||||
test_model: SteppableModel,
|
||||
comparator: Optional[TensorComparator] = None,
|
||||
stop_on_error: bool = True,
|
||||
verbose: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
ref_model: Reference (torch) steppable model
|
||||
test_model: Test (nanovllm) steppable model
|
||||
comparator: Tensor comparator instance (uses default if None)
|
||||
stop_on_error: If True, stop at first mismatch
|
||||
verbose: If True, print comparison results
|
||||
"""
|
||||
self.ref_model = ref_model
|
||||
self.test_model = test_model
|
||||
self.comparator = comparator or TensorComparator()
|
||||
self.stop_on_error = stop_on_error
|
||||
self.verbose = verbose
|
||||
|
||||
def align(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
) -> AlignmentResult:
|
||||
"""
|
||||
Run both models with same input, comparing at each breakpoint.
|
||||
|
||||
Args:
|
||||
input_ids: Input token IDs
|
||||
positions: Position IDs (optional)
|
||||
is_prefill: True for prefill phase, False for decode
|
||||
|
||||
Returns:
|
||||
AlignmentResult with pass/fail status and details
|
||||
"""
|
||||
all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = []
|
||||
|
||||
if self.verbose:
|
||||
phase = "prefill" if is_prefill else "decode"
|
||||
print(f"\n{'='*60}")
|
||||
print(f"Alignment Test ({phase})")
|
||||
print(f"{'='*60}")
|
||||
|
||||
# Start both generators
|
||||
ref_gen = self.ref_model.step(input_ids, positions, is_prefill)
|
||||
test_gen = self.test_model.step(input_ids, positions, is_prefill)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Get next breakpoint from reference
|
||||
try:
|
||||
ref_bp = next(ref_gen)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
# Get corresponding breakpoint from test
|
||||
try:
|
||||
test_bp = next(test_gen)
|
||||
except StopIteration:
|
||||
if self.verbose:
|
||||
print(f"Test model ended early at {ref_bp.name}")
|
||||
return AlignmentResult(
|
||||
passed=False,
|
||||
all_comparisons=all_comparisons,
|
||||
failed_at=ref_bp,
|
||||
message=f"Test model ended early at {ref_bp.name}",
|
||||
)
|
||||
|
||||
# Verify breakpoints match
|
||||
if ref_bp.bp_type != test_bp.bp_type:
|
||||
msg = f"Breakpoint type mismatch: {ref_bp.bp_type} vs {test_bp.bp_type}"
|
||||
if self.verbose:
|
||||
print(msg)
|
||||
return AlignmentResult(
|
||||
passed=False,
|
||||
all_comparisons=all_comparisons,
|
||||
failed_at=ref_bp,
|
||||
message=msg,
|
||||
)
|
||||
|
||||
if ref_bp.layer_idx != test_bp.layer_idx:
|
||||
msg = f"Layer index mismatch: {ref_bp.layer_idx} vs {test_bp.layer_idx}"
|
||||
if self.verbose:
|
||||
print(msg)
|
||||
return AlignmentResult(
|
||||
passed=False,
|
||||
all_comparisons=all_comparisons,
|
||||
failed_at=ref_bp,
|
||||
message=msg,
|
||||
)
|
||||
|
||||
# Normalize shapes for comparison
|
||||
ref_t = ref_bp.normalize_shape()
|
||||
test_t = test_bp.normalize_shape()
|
||||
|
||||
# Handle shape mismatches
|
||||
if ref_t.shape != test_t.shape:
|
||||
if self.verbose:
|
||||
print(f"[{ref_bp.name}] Shape mismatch: ref={ref_t.shape} vs test={test_t.shape}")
|
||||
|
||||
# Try to reshape if element count matches
|
||||
if ref_t.numel() == test_t.numel():
|
||||
test_t = test_t.view(ref_t.shape)
|
||||
else:
|
||||
msg = f"Shape mismatch at {ref_bp.name}: {ref_t.shape} vs {test_t.shape}"
|
||||
return AlignmentResult(
|
||||
passed=False,
|
||||
all_comparisons=all_comparisons,
|
||||
failed_at=ref_bp,
|
||||
message=msg,
|
||||
)
|
||||
|
||||
# Compare tensors
|
||||
result = self.comparator.compare(ref_t, test_t, ref_bp.name)
|
||||
all_comparisons.append((ref_bp, test_bp, result))
|
||||
|
||||
if self.verbose:
|
||||
status = "\u2713" if result.passed else "\u2717"
|
||||
print(f"{status} [{ref_bp.name}] cos={result.cosine_similarity:.6f}, max_diff={result.max_abs_diff:.2e}")
|
||||
|
||||
if not result.passed and self.stop_on_error:
|
||||
if self.verbose:
|
||||
print(f"\nStopped at {ref_bp.name} (stop_on_error=True)")
|
||||
print(result.message)
|
||||
return AlignmentResult(
|
||||
passed=False,
|
||||
all_comparisons=all_comparisons,
|
||||
failed_at=ref_bp,
|
||||
message=f"Alignment failed at {ref_bp.name}",
|
||||
)
|
||||
|
||||
# Check for extra test breakpoints
|
||||
try:
|
||||
extra_bp = next(test_gen)
|
||||
msg = f"Test model has extra breakpoints starting at {extra_bp.name}"
|
||||
if self.verbose:
|
||||
print(msg)
|
||||
return AlignmentResult(
|
||||
passed=False,
|
||||
all_comparisons=all_comparisons,
|
||||
message=msg,
|
||||
)
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
msg = f"Exception during alignment: {str(e)}"
|
||||
if self.verbose:
|
||||
print(msg)
|
||||
raise
|
||||
|
||||
# Summary
|
||||
all_passed = all(comp[2].passed for comp in all_comparisons)
|
||||
passed_count = sum(1 for _, _, c in all_comparisons if c.passed)
|
||||
total = len(all_comparisons)
|
||||
|
||||
if self.verbose:
|
||||
print(f"{'='*60}")
|
||||
status = "PASSED" if all_passed else "FAILED"
|
||||
print(f"Result: {status} ({passed_count}/{total} breakpoints)")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
return AlignmentResult(
|
||||
passed=all_passed,
|
||||
all_comparisons=all_comparisons,
|
||||
message="All breakpoints aligned" if all_passed else "Some breakpoints failed",
|
||||
)
|
||||

39  nanovllm/debug/breakpoints.py  Normal file
@@ -0,0 +1,39 @@
"""Breakpoint types and data structures for alignment debugging."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Optional
|
||||
import torch
|
||||
|
||||
|
||||
class BreakpointType(Enum):
|
||||
"""Types of breakpoints in the model forward pass."""
|
||||
EMBEDDING = auto() # After embed_tokens
|
||||
LAYER_OUTPUT = auto() # After each decoder layer
|
||||
FINAL_NORM = auto() # After final RMSNorm
|
||||
LM_HEAD = auto() # After lm_head (logits)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Breakpoint:
|
||||
"""A captured breakpoint with tensor data."""
|
||||
bp_type: BreakpointType
|
||||
layer_idx: Optional[int] # None for EMBEDDING, FINAL_NORM, LM_HEAD
|
||||
tensor: torch.Tensor
|
||||
name: str
|
||||
|
||||
def normalize_shape(self) -> torch.Tensor:
|
||||
"""
|
||||
Normalize tensor shape for comparison.
|
||||
|
||||
nanovllm uses [num_tokens, hidden_size] while torch uses
|
||||
[batch, seq_len, hidden_size]. This adds a batch dimension
|
||||
to 2D tensors for comparison.
|
||||
"""
|
||||
if self.tensor.dim() == 2:
|
||||
return self.tensor.unsqueeze(0)
|
||||
return self.tensor
|
||||
|
||||
def __repr__(self) -> str:
|
||||
shape_str = "x".join(str(d) for d in self.tensor.shape)
|
||||
return f"Breakpoint({self.name}, shape={shape_str}, dtype={self.tensor.dtype})"
|
||||
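A quick usage example for the dataclass above (constructed with a random tensor purely for illustration):

```python
import torch
from nanovllm.debug.breakpoints import Breakpoint, BreakpointType

bp = Breakpoint(
    bp_type=BreakpointType.LAYER_OUTPUT,
    layer_idx=3,
    tensor=torch.randn(128, 1024),   # nanovllm-style [num_tokens, hidden_size]
    name="Layer 3",
)
print(bp)                            # Breakpoint(Layer 3, shape=128x1024, dtype=torch.float32)
print(bp.normalize_shape().shape)    # torch.Size([1, 128, 1024])
```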
94
nanovllm/debug/comparator.py
Normal file
94
nanovllm/debug/comparator.py
Normal file
@@ -0,0 +1,94 @@
|
||||
"""Tensor comparison utilities for alignment debugging."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComparisonResult:
|
||||
"""Result of comparing two tensors."""
|
||||
passed: bool
|
||||
cosine_similarity: float
|
||||
max_abs_diff: float
|
||||
mean_abs_diff: float
|
||||
message: str
|
||||
|
||||
def __repr__(self) -> str:
|
||||
status = "\u2713" if self.passed else "\u2717"
|
||||
return f"{status} cos={self.cosine_similarity:.6f}, max_diff={self.max_abs_diff:.2e}"
|
||||
|
||||
|
||||
class TensorComparator:
|
||||
"""Compares tensors using cosine similarity and absolute differences."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cosine_threshold: float = 0.999,
|
||||
max_diff_threshold: float = 0.1,
|
||||
mean_diff_threshold: float = 0.01,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
cosine_threshold: Minimum cosine similarity to pass (0-1)
|
||||
max_diff_threshold: Maximum allowed absolute difference
|
||||
mean_diff_threshold: Maximum allowed mean absolute difference
|
||||
"""
|
||||
self.cosine_threshold = cosine_threshold
|
||||
self.max_diff_threshold = max_diff_threshold
|
||||
self.mean_diff_threshold = mean_diff_threshold
|
||||
|
||||
def compare(
|
||||
self,
|
||||
ref: torch.Tensor,
|
||||
test: torch.Tensor,
|
||||
name: str = "",
|
||||
) -> ComparisonResult:
|
||||
"""
|
||||
Compare two tensors and return detailed result.
|
||||
|
||||
Args:
|
||||
ref: Reference tensor
|
||||
test: Test tensor
|
||||
name: Name for the comparison (used in message)
|
||||
|
||||
Returns:
|
||||
ComparisonResult with pass/fail status and metrics
|
||||
"""
|
||||
# Convert to float32 for comparison
|
||||
ref_f = ref.float().flatten()
|
||||
test_f = test.float().flatten()
|
||||
|
||||
# Cosine similarity
|
||||
cos_sim = F.cosine_similarity(
|
||||
ref_f.unsqueeze(0),
|
||||
test_f.unsqueeze(0)
|
||||
).item()
|
||||
|
||||
# Absolute differences
|
||||
diff = (ref.float() - test.float()).abs()
|
||||
max_diff = diff.max().item()
|
||||
mean_diff = diff.mean().item()
|
||||
|
||||
# Check thresholds
|
||||
passed = (
|
||||
cos_sim >= self.cosine_threshold and
|
||||
max_diff <= self.max_diff_threshold and
|
||||
mean_diff <= self.mean_diff_threshold
|
||||
)
|
||||
|
||||
status = "PASS" if passed else "FAIL"
|
||||
message = (
|
||||
f"[{name}] {status}\n"
|
||||
f" Cosine Similarity: {cos_sim:.6f} (threshold: {self.cosine_threshold})\n"
|
||||
f" Max Abs Diff: {max_diff:.6f} (threshold: {self.max_diff_threshold})\n"
|
||||
f" Mean Abs Diff: {mean_diff:.6f} (threshold: {self.mean_diff_threshold})"
|
||||
)
|
||||
|
||||
return ComparisonResult(
|
||||
passed=passed,
|
||||
cosine_similarity=cos_sim,
|
||||
max_abs_diff=max_diff,
|
||||
mean_abs_diff=mean_diff,
|
||||
message=message,
|
||||
)
|
||||

51  nanovllm/debug/utils.py  Normal file
@@ -0,0 +1,51 @@
"""Utility functions for breakpoint alignment debugging."""
|
||||
|
||||
import torch
|
||||
from nanovllm.utils.context import set_context, reset_context
|
||||
|
||||
|
||||
def setup_prefill_context(seq_len: int, device: torch.device):
|
||||
"""
|
||||
Set up nanovllm context for prefill alignment testing.
|
||||
|
||||
Args:
|
||||
seq_len: Sequence length
|
||||
device: Target device
|
||||
"""
|
||||
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
|
||||
slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)
|
||||
|
||||
set_context(
|
||||
is_prefill=True,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=seq_len,
|
||||
max_seqlen_k=seq_len,
|
||||
slot_mapping=slot_mapping,
|
||||
is_chunked_prefill=False,
|
||||
)
|
||||
|
||||
|
||||
def setup_decode_context(context_len: int, device: torch.device):
|
||||
"""
|
||||
Set up nanovllm context for decode alignment testing.
|
||||
|
||||
Args:
|
||||
context_len: Context length (number of previous tokens)
|
||||
device: Target device
|
||||
"""
|
||||
context_lens = torch.tensor([context_len], dtype=torch.int32, device=device)
|
||||
slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
|
||||
block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)
|
||||
|
||||
set_context(
|
||||
is_prefill=False,
|
||||
slot_mapping=slot_mapping,
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
)
|
||||
|
||||
|
||||
def cleanup_context():
|
||||
"""Reset nanovllm context after alignment testing."""
|
||||
reset_context()
|
||||
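A hedged usage sketch for these helpers; the nanovllm model call signature `model(input_ids, positions)` follows the adapter code elsewhere in this diff, and `nanovllm_model` is assumed to be an already-loaded model on GPU:

```python
import torch
from nanovllm.debug.utils import setup_prefill_context, cleanup_context

input_ids = torch.tensor([1, 2, 3, 4], device="cuda")       # [num_tokens], no batch dim
positions = torch.arange(input_ids.numel(), device="cuda")

setup_prefill_context(seq_len=input_ids.numel(), device=input_ids.device)
try:
    with torch.no_grad():
        hidden_states = nanovllm_model(input_ids, positions)  # assumes a loaded model
finally:
    cleanup_context()
```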

@@ -35,7 +35,10 @@ class ModelRunner:
         self.model = Qwen3ForCausalLM(hf_config)
         load_model(self.model, config.model)
         self.sampler = GreedySampler()
-        self.warmup_model()
+
+        #> Disable warmup for debugging
+        # self.warmup_model()
+
         self.allocate_kv_cache()
         if not self.enforce_eager:
             self.capture_cudagraph()
@@ -194,7 +197,7 @@ class ModelRunner:
             f"block_size={self.block_size}"
         )

-        # Bind layer caches to attention modules and set layer_id
+        #> Bind layer caches to attention modules and set layer_id
         layer_id = 0
         for module in self.model.modules():
             if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
@@ -480,7 +483,7 @@ class ModelRunner:
                 if input_ids.numel() == 0:
                     break

-                # Run model forward
+                #> Run model forward
                 logits = self.run_model(input_ids, positions, is_prefill=True)
                 reset_context()

@@ -12,7 +12,7 @@ class SequenceStatus(Enum):


 class Sequence:
-    block_size = 4096
+    block_size = 1024
     counter = count()

     def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
@@ -34,6 +34,14 @@ class Sequence:
     def __getitem__(self, key):
         return self.token_ids[key]

+    def __repr__(self):
+        ids = self.token_ids
+        if len(ids) > 20:
+            ids_str = "[" + ", ".join(map(str, ids[:10])) + ", ..., " + ", ".join(map(str, ids[-5:])) + "]"
+        else:
+            ids_str = str(ids)
+        return f"Seq(id={self.seq_id}, status={self.status.name}, tokens={self.num_tokens}, ids={ids_str})"
+
     @property
     def is_finished(self):
         return self.status == SequenceStatus.FINISHED

@@ -146,6 +146,10 @@ class HybridKVCacheManager(KVCacheManager):
         # Key: sequence id, Value: starting position where decode began in current block
         self._decode_start_pos: Dict[int, int] = {}

+        # Track original prefill length (for correct last_block_valid_tokens calculation)
+        # Key: sequence id, Value: number of tokens from prefill (before decode started)
+        self._prefill_len: Dict[int, int] = {}
+
         # Sparse attention policy (optional)
         self.sparse_policy: Optional["SparsePolicy"] = None
@@ -542,6 +546,26 @@ class HybridKVCacheManager(KVCacheManager):
         seq_id = id(seq)
         self._decode_start_pos[seq_id] = 0

+    def get_prefill_len(self, seq: Sequence) -> int:
+        """
+        Get the original prefill length for a sequence.
+
+        This is cached on first call to ensure correct last_block_valid_tokens
+        calculation during decode (the CPU blocks don't change after prefill).
+
+        Args:
+            seq: Sequence
+
+        Returns:
+            Number of tokens from prefill (before decode started)
+        """
+        seq_id = id(seq)
+        if seq_id not in self._prefill_len:
+            # First decode step - store the prefill length
+            # len(seq) - 1 because current len includes the first decode token
+            self._prefill_len[seq_id] = len(seq) - 1
+        return self._prefill_len[seq_id]
+
     def clear_decode_tracking(self, seq: Sequence) -> None:
         """
         Clear decode position tracking for sequence.
@@ -553,6 +577,7 @@ class HybridKVCacheManager(KVCacheManager):
         """
         seq_id = id(seq)
         self._decode_start_pos.pop(seq_id, None)
+        self._prefill_len.pop(seq_id, None)

     def __repr__(self) -> str:
         return (
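The point of `get_prefill_len` is the `last_block_valid_tokens` calculation in the decode attention path (see the attention.py diff further down). A standalone illustration of that arithmetic:

```python
def last_block_valid_tokens(total_prefill_tokens: int, block_size: int) -> int:
    """How many tokens of the final CPU prefill block are valid."""
    valid = total_prefill_tokens % block_size
    if valid == 0 and total_prefill_tokens > 0:
        valid = block_size          # the last block was exactly full
    return valid

# With block_size=1024: 8192 prefill tokens -> 1024 (full last block),
# 5000 prefill tokens -> 904 valid tokens in the final block.
print(last_block_valid_tokens(8192, 1024), last_block_valid_tokens(5000, 1024))
```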

@@ -39,7 +39,7 @@ class PolicyContext:
     is_prefill: bool
     """True if in prefill phase, False if in decode phase."""

-    block_size: int = 4096
+    block_size: int = 1024
     """Number of tokens per block."""

     total_kv_len: int = 0
@@ -2,8 +2,6 @@ import logging
|
||||
import torch
|
||||
import torch.cuda.nvtx
|
||||
from torch import nn
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
from nanovllm.utils.context import get_context
|
||||
@@ -12,37 +10,49 @@ from nanovllm.kvcache.sparse.policy import PolicyContext
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def store_kvcache_kernel(
|
||||
key_ptr,
|
||||
key_stride,
|
||||
value_ptr,
|
||||
value_stride,
|
||||
k_cache_ptr,
|
||||
v_cache_ptr,
|
||||
slot_mapping_ptr,
|
||||
D: tl.constexpr,
|
||||
def store_kvcache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
):
|
||||
idx = tl.program_id(0)
|
||||
slot = tl.load(slot_mapping_ptr + idx)
|
||||
if slot == -1: return
|
||||
key_offsets = idx * key_stride + tl.arange(0, D)
|
||||
value_offsets = idx * value_stride + tl.arange(0, D)
|
||||
key = tl.load(key_ptr + key_offsets)
|
||||
value = tl.load(value_ptr + value_offsets)
|
||||
cache_offsets = slot * D + tl.arange(0, D)
|
||||
tl.store(k_cache_ptr + cache_offsets, key)
|
||||
tl.store(v_cache_ptr + cache_offsets, value)
|
||||
"""
|
||||
Store key/value tensors into KV cache using slot mapping.
|
||||
|
||||
This is a pure PyTorch implementation replacing the previous Triton kernel.
|
||||
Uses index_copy_ for efficient in-place scatter operation.
|
||||
|
||||
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
|
||||
N, num_heads, head_dim = key.shape
|
||||
D = num_heads * head_dim
|
||||
assert key.stride(-1) == 1 and value.stride(-1) == 1
|
||||
assert key.stride(1) == head_dim and value.stride(1) == head_dim
|
||||
assert k_cache.stride(1) == D and v_cache.stride(1) == D
|
||||
assert slot_mapping.numel() == N
|
||||
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
|
||||
Args:
|
||||
key: [N, num_kv_heads, head_dim]
|
||||
value: [N, num_kv_heads, head_dim]
|
||||
k_cache: [num_blocks, block_size, num_kv_heads, head_dim] or similar
|
||||
v_cache: same shape as k_cache
|
||||
slot_mapping: [N] with values as flat indices, -1 means skip
|
||||
"""
|
||||
# Filter out invalid slots (slot == -1)
|
||||
valid_mask = slot_mapping >= 0
|
||||
if not valid_mask.any():
|
||||
return
|
||||
|
||||
valid_slots = slot_mapping[valid_mask]
|
||||
valid_keys = key[valid_mask] # [M, num_kv_heads, head_dim]
|
||||
valid_values = value[valid_mask]
|
||||
|
||||
# Flatten cache and KV for scatter operation
|
||||
# Cache is viewed as [total_slots, D] where D = num_kv_heads * head_dim
|
||||
N, num_kv_heads, head_dim = key.shape
|
||||
D = num_kv_heads * head_dim
|
||||
total_slots = k_cache.numel() // D
|
||||
|
||||
k_cache_flat = k_cache.view(total_slots, D)
|
||||
v_cache_flat = v_cache.view(total_slots, D)
|
||||
valid_keys_flat = valid_keys.reshape(-1, D)
|
||||
valid_values_flat = valid_values.reshape(-1, D)
|
||||
|
||||
# In-place scatter using index_copy_
|
||||
k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat)
|
||||
v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@@ -66,6 +76,28 @@ class Attention(nn.Module):
|
||||
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
|
||||
context = get_context()
|
||||
k_cache, v_cache = self.k_cache, self.v_cache
|
||||
|
||||
# Determine if we're in chunked offload mode
|
||||
is_chunked_offload = (
|
||||
context.is_chunked_prefill and
|
||||
hasattr(context, 'kvcache_manager') and
|
||||
context.kvcache_manager is not None and
|
||||
hasattr(context.kvcache_manager, 'offload_engine')
|
||||
)
|
||||
|
||||
#! Ensure synchronization before accessing k_cache/v_cache
|
||||
torch.cuda.synchronize()
|
||||
#! =======================================================
|
||||
|
||||
if is_chunked_offload:
|
||||
# Chunked offload mode: use compute_stream for store_kvcache
|
||||
# This ensures proper synchronization with per-layer offload
|
||||
compute_stream = context.kvcache_manager.offload_engine.compute_stream
|
||||
if k_cache.numel() and v_cache.numel():
|
||||
with torch.cuda.stream(compute_stream):
|
||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||
else:
|
||||
# Normal mode: store on default stream
|
||||
if k_cache.numel() and v_cache.numel():
|
||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||
|
||||
@@ -182,8 +214,25 @@ class Attention(nn.Module):
|
||||
current_chunk_idx
|
||||
)
|
||||
|
||||
# Get compute stream for all attention operations
|
||||
compute_stream = None
|
||||
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
|
||||
compute_stream = kvcache_manager.offload_engine.compute_stream
|
||||
|
||||
# Compute attention against current chunk's KV (with causal mask)
|
||||
# Use compute_stream to ensure proper sync with store_kvcache and offload
|
||||
if compute_stream is not None:
|
||||
with torch.cuda.stream(compute_stream):
|
||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
||||
current_o, current_lse = flash_attn_with_lse(
|
||||
q_batched,
|
||||
k_batched,
|
||||
v_batched,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
||||
current_o, current_lse = flash_attn_with_lse(
|
||||
q_batched,
|
||||
@@ -194,16 +243,16 @@ class Attention(nn.Module):
|
||||
)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
|
||||
# Merge with accumulated
|
||||
# Merge with accumulated (all on compute_stream for consistency)
|
||||
if o_acc is None:
|
||||
final_o = current_o
|
||||
else:
|
||||
# IMPORTANT: o_acc was computed on compute_stream. We need to sync before
|
||||
# reading it on the default stream for the merge operation.
|
||||
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
|
||||
offload_engine = kvcache_manager.offload_engine
|
||||
torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)
|
||||
|
||||
if compute_stream is not None:
|
||||
with torch.cuda.stream(compute_stream):
|
||||
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
|
||||
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
else:
|
||||
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
|
||||
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
||||
torch.cuda.nvtx.range_pop()
|
||||
@@ -222,6 +271,16 @@ class Attention(nn.Module):
|
||||
cpu_block_id = cpu_block_ids[current_chunk_idx]
|
||||
offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
|
||||
|
||||
# CRITICAL: compute_stream must wait for offload to complete
|
||||
# before the next layer's store_kvcache can overwrite the GPU slot.
|
||||
# Without this, Layer N+1's store races with Layer N's offload copy.
|
||||
compute_stream.wait_event(offload_engine.ring_slot_offload_done[write_slot])
|
||||
|
||||
# Sync default stream with compute_stream before returning
|
||||
# This ensures the result is ready for the rest of the model (layernorm, MLP)
|
||||
if compute_stream is not None:
|
||||
torch.cuda.default_stream().wait_stream(compute_stream)
|
||||
|
||||
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
|
||||
return final_o.squeeze(0)
|
||||
|
||||
@@ -318,6 +377,7 @@ class Attention(nn.Module):
|
||||
offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
|
||||
|
||||
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
|
||||
|
||||
prev_o, prev_lse = flash_attn_with_lse(
|
||||
q_batched, prev_k, prev_v,
|
||||
softmax_scale=self.scale,
|
||||
@@ -364,6 +424,7 @@ class Attention(nn.Module):
|
||||
|
||||
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
|
||||
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
|
||||
|
||||
prev_o, prev_lse = flash_attn_with_lse(
|
||||
q_batched, prev_k, prev_v,
|
||||
softmax_scale=self.scale,
|
||||
@@ -426,13 +487,15 @@ class Attention(nn.Module):
|
||||
if not cpu_block_table:
|
||||
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
|
||||
|
||||
# Calculate valid tokens in the last block
|
||||
# Note: For chunked prefill, each block is exactly block_size tokens
|
||||
# The cpu_block_table only contains full prefill blocks
|
||||
# Calculate valid tokens in the last CPU block
|
||||
# CRITICAL: Use original prefill length, not current seq length!
|
||||
# CPU blocks are fixed after prefill, their content doesn't change during decode.
|
||||
block_size = kvcache_manager.block_size
|
||||
num_prefill_blocks = len(cpu_block_table)
|
||||
# All prefill blocks are full (block_size tokens each)
|
||||
last_block_valid_tokens = block_size
|
||||
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
|
||||
last_block_valid_tokens = total_prefill_tokens % block_size
|
||||
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
||||
last_block_valid_tokens = block_size # Last block was exactly full
|
||||
|
||||
# Apply sparse policy if enabled
|
||||
if kvcache_manager.sparse_policy is not None:
|
||||
|
||||
757
tests/modeling_qwen3.py
Normal file
757
tests/modeling_qwen3.py
Normal file
@@ -0,0 +1,757 @@
|
||||
"""
|
||||
Custom Qwen3 implementation using only torch and transformers.
|
||||
This file provides a clean reference implementation for understanding the model computation graph.
|
||||
|
||||
Computation Graph:
|
||||
==================
|
||||
|
||||
Input: token_ids [batch, seq_len]
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ Embedding │ embed_tokens: [vocab_size, hidden_size]
|
||||
└─────────────┘
|
||||
│
|
||||
▼
|
||||
hidden_states [batch, seq_len, hidden_size]
|
||||
│
|
||||
▼
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ Decoder Layer (x N) │
|
||||
│ ┌───────────────────────────────────────────────────┐ │
|
||||
│ │ Self Attention Block │ │
|
||||
│ │ │ │
|
||||
│ │ input_layernorm (RMSNorm) │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ ┌─────────────────────────────────────────────┐ │ │
|
||||
│ │ │ Qwen3Attention │ │ │
|
||||
│ │ │ Q = q_proj(x) → q_norm → reshape │ │ │
|
||||
│ │ │ K = k_proj(x) → k_norm → reshape │ │ │
|
||||
│ │ │ V = v_proj(x) → reshape │ │ │
|
||||
│ │ │ │ │ │ │
|
||||
│ │ │ ▼ │ │ │
|
||||
│ │ │ Q, K = apply_rotary_pos_emb(Q, K, cos, sin)│ │ │
|
||||
│ │ │ │ │ │ │
|
||||
│ │ │ ▼ │ │ │
|
||||
│ │ │ attn_output = attention(Q, K, V) │ │ │
|
||||
│ │ │ │ │ │ │
|
||||
│ │ │ ▼ │ │ │
|
||||
│ │ │ output = o_proj(attn_output) │ │ │
|
||||
│ │ └─────────────────────────────────────────────┘ │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ hidden_states = residual + attn_output │ │
|
||||
│ └───────────────────────────────────────────────────┘ │
|
||||
│ │ │
|
||||
│ ▼ │
|
||||
│ ┌───────────────────────────────────────────────────┐ │
|
||||
│ │ MLP Block │ │
|
||||
│ │ │ │
|
||||
│ │ post_attention_layernorm (RMSNorm) │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ ┌─────────────────────────────────────────────┐ │ │
|
||||
│ │ │ Qwen3MLP │ │ │
|
||||
│ │ │ gate = gate_proj(x) │ │ │
|
||||
│ │ │ up = up_proj(x) │ │ │
|
||||
│ │ │ output = down_proj(silu(gate) * up) │ │ │
|
||||
│ │ └─────────────────────────────────────────────┘ │ │
|
||||
│ │ │ │ │
|
||||
│ │ ▼ │ │
|
||||
│ │ hidden_states = residual + mlp_output │ │
|
||||
│ └───────────────────────────────────────────────────┘ │
|
||||
└─────────────────────────────────────────────────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ norm │ final RMSNorm
|
||||
└─────────────┘
|
||||
│
|
||||
▼
|
||||
┌─────────────┐
|
||||
│ lm_head │ [hidden_size, vocab_size]
|
||||
└─────────────┘
|
||||
│
|
||||
▼
|
||||
logits [batch, seq_len, vocab_size]
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple, List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Qwen3RMSNorm(nn.Module):
|
||||
"""RMSNorm implementation."""
|
||||
|
||||
def __init__(self, hidden_size: int, eps: float = 1e-6):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
input_dtype = x.dtype
|
||||
x = x.float()
|
||||
variance = x.pow(2).mean(-1, keepdim=True)
|
||||
x = x * torch.rsqrt(variance + self.eps)
|
||||
return self.weight * x.to(input_dtype)
|
||||
|
||||
|
||||
class Qwen3RotaryEmbedding(nn.Module):
|
||||
"""Rotary Position Embedding (RoPE)."""
|
||||
|
||||
def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
|
||||
# Compute inverse frequencies
|
||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x: Input tensor [batch, seq_len, num_heads, head_dim] or similar
|
||||
position_ids: Position indices [batch, seq_len]
|
||||
|
||||
Returns:
|
||||
cos, sin: [batch, seq_len, head_dim]
|
||||
"""
|
||||
# inv_freq: [dim/2]
|
||||
# position_ids: [batch, seq_len]
|
||||
inv_freq_expanded = self.inv_freq[None, :, None].float() # [1, dim/2, 1]
|
||||
position_ids_expanded = position_ids[:, None, :].float() # [batch, 1, seq_len]
|
||||
|
||||
# freqs: [batch, dim/2, seq_len]
|
||||
freqs = inv_freq_expanded @ position_ids_expanded
|
||||
# freqs: [batch, seq_len, dim/2]
|
||||
freqs = freqs.transpose(1, 2)
|
||||
|
||||
# Duplicate for full head_dim: [batch, seq_len, dim]
|
||||
emb = torch.cat((freqs, freqs), dim=-1)
|
||||
|
||||
cos = emb.cos().to(x.dtype)
|
||||
sin = emb.sin().to(x.dtype)
|
||||
|
||||
return cos, sin
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Rotate half the hidden dims of the input."""
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
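# Example: rotate_half(torch.tensor([1., 2., 3., 4.])) -> tensor([-3., -4., 1., 2.])
# (these (x_i, x_{i+d/2}) pairs are exactly what apply_rotary_pos_emb rotates below)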
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Apply rotary position embeddings to Q and K.
|
||||
|
||||
Args:
|
||||
q: [batch, num_heads, seq_len, head_dim]
|
||||
k: [batch, num_kv_heads, seq_len, head_dim]
|
||||
cos: [batch, seq_len, head_dim]
|
||||
sin: [batch, seq_len, head_dim]
|
||||
|
||||
Returns:
|
||||
q_embed, k_embed with same shapes as inputs
|
||||
"""
|
||||
# Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim]
|
||||
cos = cos.unsqueeze(1)
|
||||
sin = sin.unsqueeze(1)
|
||||
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
|
||||
return q_embed, k_embed
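# Illustrative sketch (not part of the original file): RoPE applied with the helpers
# above is a per-pair 2D rotation, so it preserves the per-head norm -- a cheap sanity check.
#
#   rope = Qwen3RotaryEmbedding(dim=128)
#   q = torch.randn(1, 8, 16, 128)          # [batch, num_heads, seq_len, head_dim]
#   k = torch.randn(1, 2, 16, 128)          # [batch, num_kv_heads, seq_len, head_dim]
#   pos = torch.arange(16).unsqueeze(0)     # [batch, seq_len]
#   cos, sin = rope(q, pos)                 # each [1, 16, 128]
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
#   assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)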
|
||||
|
||||
|
||||
class Qwen3Attention(nn.Module):
|
||||
"""
|
||||
Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.
|
||||
|
||||
Data Flow:
|
||||
---------
|
||||
hidden_states [batch, seq_len, hidden_size]
|
||||
│
|
||||
├──► q_proj ──► q_norm ──► reshape ──► Q [batch, num_heads, seq_len, head_dim]
|
||||
├──► k_proj ──► k_norm ──► reshape ──► K [batch, num_kv_heads, seq_len, head_dim]
|
||||
└──► v_proj ──► reshape ──► V [batch, num_kv_heads, seq_len, head_dim]
|
||||
│
|
||||
▼
|
||||
apply_rotary_pos_emb(Q, K)
|
||||
│
|
||||
▼
|
||||
attention(Q, K, V) ──► attn_output [batch, num_heads, seq_len, head_dim]
|
||||
│
|
||||
▼
|
||||
reshape ──► o_proj ──► output [batch, seq_len, hidden_size]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 32768,
|
||||
rope_theta: float = 10000.0,
|
||||
attention_bias: bool = False,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
layer_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.num_kv_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.num_kv_groups = num_attention_heads // num_key_value_heads
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
# Scaling factor
|
||||
self.scaling = head_dim ** -0.5
|
||||
|
||||
# QKV projections
|
||||
self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
|
||||
self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
|
||||
self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
|
||||
self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)
|
||||
|
||||
# QK normalization (Qwen3 specific)
|
||||
self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
|
||||
self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
|
||||
|
||||
# Rotary embeddings
|
||||
self.rotary_emb = Qwen3RotaryEmbedding(
|
||||
head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
output_qkv: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states: [batch, seq_len, hidden_size]
|
||||
position_ids: [batch, seq_len]
|
||||
attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask)
|
||||
past_key_value: (k_cache, v_cache) from previous steps
|
||||
use_cache: Whether to return updated cache
|
||||
output_qkv: Whether to output Q, K, V tensors for debugging
|
||||
|
||||
Returns:
|
||||
output: [batch, seq_len, hidden_size]
|
||||
past_key_value: Updated cache (if use_cache=True)
|
||||
qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
|
||||
"""
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
# === QKV Projections ===
|
||||
q = self.q_proj(hidden_states) # [batch, seq_len, num_heads * head_dim]
|
||||
k = self.k_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
|
||||
v = self.v_proj(hidden_states) # [batch, seq_len, num_kv_heads * head_dim]
|
||||
|
||||
# Reshape to [batch, seq_len, num_heads, head_dim]
|
||||
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
|
||||
k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
|
||||
v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
|
||||
|
||||
# === QK Normalization (Qwen3 specific) ===
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Transpose to [batch, num_heads, seq_len, head_dim]
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
# === Rotary Position Embeddings ===
|
||||
cos, sin = self.rotary_emb(v, position_ids)
|
||||
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
||||
|
||||
# === KV Cache Update ===
|
||||
if past_key_value is not None:
|
||||
k_cache, v_cache = past_key_value
|
||||
k = torch.cat([k_cache, k], dim=2)
|
||||
v = torch.cat([v_cache, v], dim=2)
|
||||
|
||||
new_past_key_value = (k, v) if use_cache else None
|
||||
|
||||
# === Grouped Query Attention (expand KV heads if needed) ===
|
||||
if self.num_kv_groups > 1:
|
||||
# Repeat KV for each query group
|
||||
k = k.repeat_interleave(self.num_kv_groups, dim=1)
|
||||
v = v.repeat_interleave(self.num_kv_groups, dim=1)
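# After repeat_interleave the expanded head layout is [kv0, kv0, ..., kv1, kv1, ...],
# so query head h reads from original KV head h // num_kv_groups.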
|
||||
|
||||
# === Attention Computation (using SDPA for memory efficiency) ===
|
||||
# Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend
|
||||
# is_causal is only valid when q_len == kv_len (prefill); during decode (q_len == 1)
# the single query may attend to every cached key, so no mask is needed
|
||||
q_len, kv_len = q.shape[2], k.shape[2]
|
||||
is_causal = (q_len == kv_len) and (q_len > 1)
|
||||
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=is_causal,
|
||||
scale=self.scaling,
|
||||
) # [batch, num_heads, seq_len, head_dim]
|
||||
|
||||
# === Output Projection ===
|
||||
# Transpose back and reshape
|
||||
attn_output = attn_output.transpose(1, 2).contiguous() # [batch, seq_len, num_heads, head_dim]
|
||||
attn_output = attn_output.view(batch_size, seq_len, -1) # [batch, seq_len, hidden_size]
|
||||
output = self.o_proj(attn_output)
|
||||
|
||||
# Optional QKV output for debugging
|
||||
qkv_dict = None
|
||||
if output_qkv:
|
||||
qkv_dict = {
|
||||
"q": q, # [batch, num_heads, seq_len, head_dim] (post-RoPE)
|
||||
"k": k, # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
|
||||
"v": v, # [batch, num_heads, kv_seq_len, head_dim] (expanded)
|
||||
}
|
||||
|
||||
return output, new_past_key_value, qkv_dict
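# Illustrative prefill + single decode step for the module above (tiny hypothetical
# config, not from the original file):
#
#   attn = Qwen3Attention(hidden_size=64, num_attention_heads=4,
#                         num_key_value_heads=2, head_dim=16)
#   x = torch.randn(1, 8, 64)                                # prefill 8 tokens
#   out, kv, _ = attn(x, torch.arange(8).unsqueeze(0),
#                     use_cache=True)                        # kv[0]: [1, 2, 8, 16]
#   x_new = torch.randn(1, 1, 64)                            # decode 1 token
#   out, kv, _ = attn(x_new, torch.tensor([[8]]),
#                     past_key_value=kv, use_cache=True)     # kv[0]: [1, 2, 9, 16]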
|
||||
|
||||
|
||||
class Qwen3MLP(nn.Module):
|
||||
"""
|
||||
Qwen3 MLP with SwiGLU activation.
|
||||
|
||||
Data Flow:
|
||||
---------
|
||||
hidden_states [batch, seq_len, hidden_size]
|
||||
│
|
||||
├──► gate_proj ──► gate [batch, seq_len, intermediate_size]
|
||||
│
|
||||
└──► up_proj ──► up [batch, seq_len, intermediate_size]
|
||||
│
|
||||
▼
|
||||
silu(gate) * up
|
||||
│
|
||||
▼
|
||||
down_proj ──► output [batch, seq_len, hidden_size]
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
|
||||
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
|
||||
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
gate = self.gate_proj(x)
|
||||
up = self.up_proj(x)
|
||||
return self.down_proj(F.silu(gate) * up)
|
||||
|
||||
|
||||
class Qwen3DecoderLayer(nn.Module):
|
||||
"""Single Qwen3 Decoder Layer."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 32768,
|
||||
rope_theta: float = 10000.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
layer_idx: int = 0,
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
# Pre-attention LayerNorm
|
||||
self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
# Self-attention
|
||||
self.self_attn = Qwen3Attention(
|
||||
hidden_size=hidden_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
attention_bias=attention_bias,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
layer_idx=layer_idx,
|
||||
)
|
||||
|
||||
# Post-attention LayerNorm
|
||||
self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
# MLP
|
||||
self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
use_cache: bool = False,
|
||||
output_qkv: bool = False,
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states: [batch, seq_len, hidden_size]
|
||||
position_ids: [batch, seq_len]
|
||||
attention_mask: Causal attention mask
|
||||
past_key_value: KV cache for this layer
|
||||
use_cache: Whether to return updated cache
|
||||
output_qkv: Whether to output Q, K, V for debugging
|
||||
|
||||
Returns:
|
||||
hidden_states: [batch, seq_len, hidden_size]
|
||||
past_key_value: Updated cache
|
||||
qkv_dict: QKV tensors (if output_qkv=True)
|
||||
"""
|
||||
# === Self Attention Block ===
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
attn_output, new_past_key_value, qkv_dict = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=past_key_value,
|
||||
use_cache=use_cache,
|
||||
output_qkv=output_qkv,
|
||||
)
|
||||
|
||||
hidden_states = residual + attn_output
|
||||
|
||||
# === MLP Block ===
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
return hidden_states, new_past_key_value, qkv_dict
|
||||
|
||||
|
||||
class Qwen3Model(nn.Module):
|
||||
"""Qwen3 Transformer Model (without LM head)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_hidden_layers: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 32768,
|
||||
rope_theta: float = 10000.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
|
||||
# Token embeddings
|
||||
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
|
||||
|
||||
# Decoder layers
|
||||
self.layers = nn.ModuleList([
|
||||
Qwen3DecoderLayer(
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
mlp_bias=mlp_bias,
|
||||
layer_idx=i,
|
||||
)
|
||||
for i in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
# Final LayerNorm
|
||||
self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
use_cache: bool = False,
|
||||
output_qkv_layers: Optional[List[int]] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
|
||||
"""
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
position_ids: [batch, seq_len]
|
||||
attention_mask: [batch, seq_len] or pre-computed 4D mask
|
||||
past_key_values: List of (k, v) tuples for each layer
|
||||
use_cache: Whether to return new cache
|
||||
output_qkv_layers: List of layer indices to output QKV for
|
||||
|
||||
Returns:
|
||||
hidden_states: [batch, seq_len, hidden_size]
|
||||
new_past_key_values: Updated cache
|
||||
qkv_outputs: {layer_idx: qkv_dict}
|
||||
"""
|
||||
batch_size, seq_len = input_ids.shape
|
||||
|
||||
# Embedding
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
|
||||
# Position IDs
|
||||
if position_ids is None:
|
||||
past_len = past_key_values[0][0].shape[2] if past_key_values else 0
|
||||
position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
|
||||
position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
# Attention mask (create causal mask if not provided)
|
||||
if attention_mask is None or attention_mask.dim() == 2:
|
||||
kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
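# triu with diagonal = kv_seq_len - seq_len + 1 puts the causal boundary so that
# query i (absolute position past_len + i) sees all cached tokens plus current-chunk
# positions <= past_len + i. Note that Qwen3Attention above currently ignores this
# mask (it calls SDPA with attn_mask=None and relies on is_causal), so it only
# matters for attention implementations that consume attention_mask explicitly.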
|
||||
causal_mask = torch.triu(
|
||||
torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
|
||||
diagonal=kv_seq_len - seq_len + 1,
|
||||
)
|
||||
attention_mask = causal_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, kv_seq_len]
|
||||
|
||||
# Initialize cache list
|
||||
new_past_key_values = [] if use_cache else None
|
||||
qkv_outputs = {} if output_qkv_layers else None
|
||||
|
||||
# Decoder layers
|
||||
for i, layer in enumerate(self.layers):
|
||||
past_kv = past_key_values[i] if past_key_values else None
|
||||
output_qkv = output_qkv_layers is not None and i in output_qkv_layers
|
||||
|
||||
hidden_states, new_kv, qkv_dict = layer(
|
||||
hidden_states=hidden_states,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=past_kv,
|
||||
use_cache=use_cache,
|
||||
output_qkv=output_qkv,
|
||||
)
|
||||
|
||||
if use_cache:
|
||||
new_past_key_values.append(new_kv)
|
||||
if qkv_dict is not None:
|
||||
qkv_outputs[i] = qkv_dict
|
||||
|
||||
# Final norm
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
return hidden_states, new_past_key_values, qkv_outputs
|
||||
|
||||
|
||||
class Qwen3ForCausalLM(nn.Module):
|
||||
"""Qwen3 Model with Language Modeling head."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
num_hidden_layers: int,
|
||||
num_attention_heads: int,
|
||||
num_key_value_heads: int,
|
||||
head_dim: int,
|
||||
max_position_embeddings: int = 32768,
|
||||
rope_theta: float = 10000.0,
|
||||
rms_norm_eps: float = 1e-6,
|
||||
attention_bias: bool = False,
|
||||
mlp_bias: bool = False,
|
||||
tie_word_embeddings: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.vocab_size = vocab_size
|
||||
self.tie_word_embeddings = tie_word_embeddings
|
||||
|
||||
# Transformer model
|
||||
self.model = Qwen3Model(
|
||||
vocab_size=vocab_size,
|
||||
hidden_size=hidden_size,
|
||||
intermediate_size=intermediate_size,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_key_value_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
rope_theta=rope_theta,
|
||||
rms_norm_eps=rms_norm_eps,
|
||||
attention_bias=attention_bias,
|
||||
mlp_bias=mlp_bias,
|
||||
)
|
||||
|
||||
# LM head
|
||||
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
position_ids: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
|
||||
use_cache: bool = False,
|
||||
output_qkv_layers: Optional[List[int]] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
|
||||
"""
|
||||
Args:
|
||||
input_ids: [batch, seq_len]
|
||||
... (same as Qwen3Model)
|
||||
|
||||
Returns:
|
||||
logits: [batch, seq_len, vocab_size]
|
||||
past_key_values: Updated KV cache
|
||||
qkv_outputs: QKV tensors for specified layers
|
||||
"""
|
||||
hidden_states, new_past_key_values, qkv_outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
position_ids=position_ids,
|
||||
attention_mask=attention_mask,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=use_cache,
|
||||
output_qkv_layers=output_qkv_layers,
|
||||
)
|
||||
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
return logits, new_past_key_values, qkv_outputs
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
|
||||
"""
|
||||
Load weights from a pretrained Qwen3 model.
|
||||
|
||||
Args:
|
||||
model_path: Path to model directory containing config.json and model weights
|
||||
dtype: Data type for model weights
|
||||
|
||||
Returns:
|
||||
Initialized Qwen3ForCausalLM model
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
from safetensors.torch import load_file
|
||||
|
||||
# Load config
|
||||
config_path = os.path.join(model_path, "config.json")
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Create model
|
||||
model = cls(
|
||||
vocab_size=config["vocab_size"],
|
||||
hidden_size=config["hidden_size"],
|
||||
intermediate_size=config["intermediate_size"],
|
||||
num_hidden_layers=config["num_hidden_layers"],
|
||||
num_attention_heads=config["num_attention_heads"],
|
||||
num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
|
||||
head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
|
||||
max_position_embeddings=config.get("max_position_embeddings", 32768),
|
||||
rope_theta=config.get("rope_theta", 10000.0),
|
||||
rms_norm_eps=config.get("rms_norm_eps", 1e-6),
|
||||
attention_bias=config.get("attention_bias", False),
|
||||
mlp_bias=config.get("mlp_bias", False),
|
||||
tie_word_embeddings=config.get("tie_word_embeddings", True),
|
||||
)
|
||||
|
||||
# Load weights
|
||||
weight_files = sorted([
|
||||
f for f in os.listdir(model_path)
|
||||
if f.endswith(".safetensors")
|
||||
])
|
||||
|
||||
state_dict = {}
|
||||
for wf in weight_files:
|
||||
state_dict.update(load_file(os.path.join(model_path, wf)))
|
||||
|
||||
# Load into model
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
# Tie lm_head weights to embed_tokens if configured
|
||||
if model.tie_word_embeddings:
|
||||
model.lm_head.weight = model.model.embed_tokens.weight
|
||||
|
||||
model = model.to(dtype)
|
||||
|
||||
return model
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
max_new_tokens: int = 32,
|
||||
temperature: float = 1.0,
|
||||
do_sample: bool = True,
|
||||
pad_token_id: Optional[int] = None,
|
||||
eos_token_id: Optional[int] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Simple autoregressive generation."""
|
||||
device = input_ids.device
|
||||
batch_size, seq_len = input_ids.shape
|
||||
past_key_values = None
|
||||
generated = input_ids.clone()
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
if past_key_values is None:
|
||||
current_input = generated
|
||||
else:
|
||||
current_input = generated[:, -1:]
|
||||
|
||||
logits, past_key_values, _ = self(
|
||||
input_ids=current_input,
|
||||
past_key_values=past_key_values,
|
||||
use_cache=True,
|
||||
)
|
||||
|
||||
next_token_logits = logits[:, -1, :]
|
||||
if temperature > 0 and do_sample:
|
||||
next_token_logits = next_token_logits / temperature
|
||||
probs = torch.softmax(next_token_logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
next_token = next_token_logits.argmax(dim=-1, keepdim=True)
|
||||
|
||||
generated = torch.cat([generated, next_token], dim=1)
|
||||
|
||||
if eos_token_id is not None and (next_token == eos_token_id).all():
|
||||
break
|
||||
|
||||
return generated
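# Minimal end-to-end usage sketch (the model path and prompt are assumptions, not
# part of the original file):
#
#   import os
#   from transformers import AutoTokenizer
#   path = os.path.expanduser("~/models/Qwen3-0.6B/")
#   tok = AutoTokenizer.from_pretrained(path)
#   model = Qwen3ForCausalLM.from_pretrained(path, dtype=torch.float16).to("cuda")
#   ids = tok.encode("The capital of France is", return_tensors="pt").to("cuda")
#   out = model.generate(ids, max_new_tokens=8, do_sample=False)
#   print(tok.decode(out[0, ids.shape[1]:]))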
|
||||
|
||||
|
||||
def print_computation_graph():
|
||||
"""Print the computation graph for reference."""
|
||||
print(__doc__)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print_computation_graph()
|
||||
365
tests/test_align.py
Normal file
365
tests/test_align.py
Normal file
@@ -0,0 +1,365 @@
|
||||
"""
|
||||
Test alignment between nanovllm and custom torch Qwen3 implementation.
|
||||
Compares attention layer outputs and QKV tensors to verify correctness.
|
||||
|
||||
Usage:
|
||||
python test_align.py # Without CPU offload
|
||||
python test_align.py --enable-offload # With CPU offload
|
||||
python test_align.py --input-len 4096 # Custom input length
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from nanovllm import LLM, SamplingParams
|
||||
from modeling_qwen3 import Qwen3ForCausalLM
|
||||
from utils import generate_needle_prompt
|
||||
|
||||
# Parse arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--enable-offload", action="store_true", help="Enable CPU offload")
|
||||
parser.add_argument("--input-len", type=int, default=1024 * 12, help="Input sequence length")
|
||||
parser.add_argument("--model-path", type=str, default="~/models/Qwen3-0.6B/", help="Model path")
|
||||
parser.add_argument("--num-gpu-blocks", type=int, default=6, help="Number of GPU blocks (ring buffer slots)")
|
||||
parser.add_argument("--block-size", type=int, default=1024, help="KV cache block size")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Config
|
||||
MODEL_PATH = os.path.expanduser(args.model_path)
|
||||
INPUT_LEN = args.input_len
|
||||
ENABLE_OFFLOAD = args.enable_offload
|
||||
NUM_GPU_BLOCKS = args.num_gpu_blocks
|
||||
BLOCK_SIZE = args.block_size
|
||||
DTYPE = torch.float16
|
||||
|
||||
print(f"Config: input_len={INPUT_LEN}, enable_offload={ENABLE_OFFLOAD}, num_gpu_blocks={NUM_GPU_BLOCKS}, block_size={BLOCK_SIZE}")
|
||||
|
||||
# Storage for captured tensors
|
||||
nanovllm_outputs = {}
|
||||
torch_outputs = {}
|
||||
nanovllm_qkv = {}
|
||||
nanovllm_proj_inputs = {}
|
||||
torch_proj_inputs = {}
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Hook functions for non-offload mode (overwrite)
|
||||
# ============================================================
|
||||
def make_nanovllm_hook(layer_id: int, storage: dict):
|
||||
def hook(module, inputs, output):
|
||||
attn_output = output[0] if isinstance(output, tuple) else output
|
||||
if attn_output.dim() == 2:
|
||||
attn_output = attn_output.unsqueeze(0)
|
||||
storage[layer_id] = attn_output.detach().clone()
|
||||
return hook
|
||||
|
||||
|
||||
def make_nanovllm_qkv_hook(layer_id: int, storage: dict):
|
||||
def hook(module, inputs):
|
||||
q, k, v = inputs[0], inputs[1], inputs[2]
|
||||
storage[layer_id] = {
|
||||
"q": q.detach().clone(),
|
||||
"k": k.detach().clone(),
|
||||
"v": v.detach().clone(),
|
||||
}
|
||||
return hook
|
||||
|
||||
|
||||
def make_proj_input_hook(layer_id: int, storage: dict):
|
||||
def hook(module, inputs):
|
||||
hidden = inputs[0]
|
||||
if hidden.dim() == 2:
|
||||
hidden = hidden.unsqueeze(0)
|
||||
storage[layer_id] = hidden.detach().clone()
|
||||
return hook
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Hook functions for offload mode (accumulate Q and I, overwrite O)
|
||||
# ============================================================
|
||||
def make_accumulating_q_hook(layer_id: int, storage: dict):
|
||||
"""Accumulate Q from each chunk for offload mode."""
|
||||
def hook(module, inputs):
|
||||
q = inputs[0].detach().clone()
|
||||
if layer_id not in storage:
|
||||
storage[layer_id] = []
|
||||
storage[layer_id].append(q)
|
||||
return hook
|
||||
|
||||
|
||||
def make_accumulating_input_hook(layer_id: int, storage: dict):
|
||||
"""Accumulate input hidden states from each chunk for offload mode."""
|
||||
def hook(module, inputs):
|
||||
hidden = inputs[0].detach().clone()
|
||||
if layer_id not in storage:
|
||||
storage[layer_id] = []
|
||||
storage[layer_id].append(hidden)
|
||||
return hook
|
||||
|
||||
|
||||
def make_overwrite_output_hook(layer_id: int, storage: dict):
|
||||
"""Overwrite output (keep only last chunk) for offload mode."""
|
||||
def hook(module, inputs, output):
|
||||
attn_output = output[0] if isinstance(output, tuple) else output
|
||||
if attn_output.dim() == 2:
|
||||
attn_output = attn_output.unsqueeze(0)
|
||||
storage[layer_id] = attn_output.detach().clone()
|
||||
return hook
|
||||
|
||||
|
||||
# ============================================================
|
||||
# CPU KV cache access for offload mode
|
||||
# ============================================================
|
||||
def get_nanovllm_kv_from_cpu(llm, seq, num_layers):
|
||||
"""Get complete K, V cache from CPU side after all chunks finish."""
|
||||
offload_engine = llm.model_runner.kvcache_manager.offload_engine
|
||||
kvcache_manager = llm.model_runner.kvcache_manager
|
||||
|
||||
# CRITICAL: Synchronize all CUDA operations before reading CPU memory
|
||||
# The D2H copy runs on transfer_stream_main and may still be in progress
|
||||
torch.cuda.synchronize()
|
||||
|
||||
cpu_block_ids = kvcache_manager.get_cpu_block_table(seq)
|
||||
|
||||
kv_per_layer = {}
|
||||
for layer_id in range(num_layers):
|
||||
k_blocks = []
|
||||
v_blocks = []
|
||||
for cpu_block_id in cpu_block_ids:
|
||||
k_block, v_block = offload_engine.get_cpu_block(layer_id, cpu_block_id)
|
||||
k_blocks.append(k_block)
|
||||
v_blocks.append(v_block)
|
||||
|
||||
# Concatenate all blocks: [total_tokens, kv_heads, head_dim]
|
||||
k_full = torch.cat(k_blocks, dim=0)[:seq.num_tokens]
|
||||
v_full = torch.cat(v_blocks, dim=0)[:seq.num_tokens]
|
||||
kv_per_layer[layer_id] = {"k": k_full, "v": v_full}
|
||||
|
||||
return kv_per_layer
|
||||
|
||||
|
||||
def make_torch_hook(layer_id: int, storage: dict):
|
||||
def hook(module, inputs, output):
|
||||
storage[layer_id] = output[0].detach().clone()
|
||||
return hook
|
||||
|
||||
|
||||
def cosine_sim(t1: torch.Tensor, t2: torch.Tensor) -> float:
|
||||
"""Cosine similarity between flattened tensors (1.0 = identical)."""
|
||||
return torch.nn.functional.cosine_similarity(
|
||||
t1.flatten().float(), t2.flatten().float(), dim=0
|
||||
).item()
|
||||
|
||||
|
||||
def compute_qkv_sims(nano_qkv: dict, torch_qkv: dict, num_kv_groups: int):
|
||||
"""Compute Q, K, V cosine similarities. Returns (q_sim, k_sim, v_sim)."""
|
||||
nano_q = nano_qkv["q"]
|
||||
torch_q = torch_qkv["q"].squeeze(0).transpose(0, 1)
|
||||
|
||||
nano_k = nano_qkv["k"]
|
||||
torch_k = torch_qkv["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
|
||||
|
||||
nano_v = nano_qkv["v"]
|
||||
torch_v = torch_qkv["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
|
||||
|
||||
return cosine_sim(nano_q, torch_q), cosine_sim(nano_k, torch_k), cosine_sim(nano_v, torch_v)
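# Layout note: nanovllm exposes Q/K/V as [seq_len, num_heads, head_dim], while the torch
# reference returns [1, num_heads, seq_len, head_dim] with KV heads already expanded by
# repeat_interleave. squeeze(0) + [::num_kv_groups] + transpose(0, 1) undoes the expansion
# and matches the nanovllm layout before comparison.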
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Load models
|
||||
# ============================================================
|
||||
print("Loading nanovllm model...")
|
||||
llm_kwargs = dict(
|
||||
enforce_eager=True,
|
||||
max_model_len=32768,
|
||||
gpu_memory_utilization=0.2,
|
||||
max_num_batched_tokens=32768,
|
||||
enable_cpu_offload=ENABLE_OFFLOAD,
|
||||
dtype="float16",
|
||||
kvcache_block_size=BLOCK_SIZE,
|
||||
)
|
||||
if ENABLE_OFFLOAD:
|
||||
llm_kwargs["num_gpu_blocks"] = NUM_GPU_BLOCKS
|
||||
llm = LLM(MODEL_PATH, **llm_kwargs)
|
||||
|
||||
num_heads = llm.model_runner.model.model.layers[0].self_attn.num_heads
|
||||
num_kv_heads = llm.model_runner.model.model.layers[0].self_attn.num_kv_heads
|
||||
num_kv_groups = num_heads // num_kv_heads
|
||||
num_layers = len(llm.model_runner.model.model.layers)
|
||||
|
||||
print("Loading torch model...")
|
||||
torch_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
|
||||
torch_model = torch_model.to("cuda")
|
||||
torch_model.eval()
|
||||
|
||||
# ============================================================
|
||||
# Generate test input
|
||||
# ============================================================
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
||||
prompt, _ = generate_needle_prompt(tokenizer=tokenizer, target_length=INPUT_LEN, verbose=True)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
|
||||
print(f"Input shape: {input_ids.shape}")
|
||||
|
||||
# ============================================================
|
||||
# Register hooks
|
||||
# ============================================================
|
||||
nanovllm_hooks = []
|
||||
nanovllm_q_accum = {} # For offload mode: accumulated Q from all chunks
|
||||
nanovllm_i_accum = {} # For offload mode: accumulated I from all chunks
|
||||
|
||||
for layer_idx, layer in enumerate(llm.model_runner.model.model.layers):
|
||||
if ENABLE_OFFLOAD:
|
||||
# Offload mode: accumulate Q and I, overwrite O
|
||||
nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_overwrite_output_hook(layer_idx, nanovllm_outputs)))
|
||||
nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_accumulating_q_hook(layer_idx, nanovllm_q_accum)))
|
||||
nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_accumulating_input_hook(layer_idx, nanovllm_i_accum)))
|
||||
else:
|
||||
# Non-offload mode: overwrite all
|
||||
nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_nanovllm_hook(layer_idx, nanovllm_outputs)))
|
||||
nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_nanovllm_qkv_hook(layer_idx, nanovllm_qkv)))
|
||||
nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, nanovllm_proj_inputs)))
|
||||
|
||||
torch_hooks = []
|
||||
for layer_idx, layer in enumerate(torch_model.model.layers):
|
||||
torch_hooks.append(layer.self_attn.register_forward_hook(make_torch_hook(layer_idx, torch_outputs)))
|
||||
torch_hooks.append(layer.self_attn.q_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, torch_proj_inputs)))
|
||||
|
||||
# ============================================================
|
||||
# Run inference
|
||||
# ============================================================
|
||||
print("Running nanovllm inference...")
|
||||
|
||||
if ENABLE_OFFLOAD:
|
||||
# Manual execution to capture KV cache before deallocation
|
||||
# Use max_tokens=2 so sequence doesn't finish immediately after prefill
|
||||
llm.add_request(input_ids[0].tolist(), SamplingParams(temperature=0.01, max_tokens=2))
|
||||
|
||||
# Run prefill step (this calls run_chunked_offload_prefill internally)
|
||||
output, num_tokens = llm.step()
|
||||
print(f"[Offload] Prefill done: {num_tokens} tokens")
|
||||
|
||||
# Now seq is in running queue, KV cache is in CPU
|
||||
seq = llm.scheduler.running[0]
|
||||
print(f"[Offload] Sequence: {seq}")
|
||||
|
||||
# Get KV cache from CPU BEFORE decode step deallocates it
|
||||
nanovllm_kv_cpu = get_nanovllm_kv_from_cpu(llm, seq, num_layers)
|
||||
print(f"[Offload] Retrieved KV cache from CPU for {seq.num_tokens} tokens")
|
||||
|
||||
# IMPORTANT: Save outputs NOW before decode step overwrites them
|
||||
# nanovllm_outputs contains prefill attention outputs at this point
|
||||
nanovllm_outputs_prefill = {k: v.clone() for k, v in nanovllm_outputs.items()}
|
||||
|
||||
# Complete remaining steps (decode)
|
||||
while not llm.is_finished():
|
||||
llm.step()
|
||||
|
||||
# Use prefill outputs for comparison
|
||||
nanovllm_outputs = nanovllm_outputs_prefill
|
||||
else:
|
||||
nanovllm_result = llm.generate([input_ids[0].tolist()], SamplingParams(temperature=0.01, max_tokens=1), use_tqdm=False)
|
||||
|
||||
print("Running torch inference...")
|
||||
with torch.no_grad():
|
||||
torch_logits, _, torch_qkv_outputs = torch_model(input_ids, output_qkv_layers=list(range(num_layers)))
|
||||
|
||||
# ============================================================
|
||||
# Compare using cosine similarity (1.0 = perfect alignment)
|
||||
# ============================================================
|
||||
print("\n" + "=" * 70)
|
||||
print(f"{'Layer':<8} {'I':>10} {'Q':>10} {'K':>10} {'V':>10} {'O':>10}")
|
||||
print("=" * 70)
|
||||
|
||||
all_passed = True
|
||||
threshold = 0.999 # Cosine similarity threshold
|
||||
|
||||
for layer_idx in range(num_layers):
|
||||
if ENABLE_OFFLOAD:
|
||||
# ============================================================
|
||||
# Offload mode: use accumulated Q/I and CPU-side K/V
|
||||
# Only compare prompt tokens (INPUT_LEN), exclude generated tokens
|
||||
# ============================================================
|
||||
# I: concatenate accumulated chunks, trim to prompt length
|
||||
i_chunks = nanovllm_i_accum[layer_idx]
|
||||
nano_in = torch.cat(i_chunks, dim=0)[:INPUT_LEN]
|
||||
if nano_in.dim() == 2:
|
||||
nano_in = nano_in.unsqueeze(0)
|
||||
torch_in = torch_proj_inputs[layer_idx]
|
||||
if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
|
||||
torch_in = torch_in.view(nano_in.shape)
|
||||
i_sim = cosine_sim(nano_in, torch_in)
|
||||
|
||||
# Q: concatenate accumulated chunks, trim to prompt length
|
||||
q_chunks = nanovllm_q_accum[layer_idx]
|
||||
nano_q = torch.cat(q_chunks, dim=0)[:INPUT_LEN]
|
||||
torch_q = torch_qkv_outputs[layer_idx]["q"].squeeze(0).transpose(0, 1)
|
||||
q_sim = cosine_sim(nano_q, torch_q)
|
||||
|
||||
# K, V: from CPU cache, trim to prompt length and move to GPU
|
||||
nano_k = nanovllm_kv_cpu[layer_idx]["k"][:INPUT_LEN].cuda()
|
||||
torch_k = torch_qkv_outputs[layer_idx]["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
|
||||
k_sim = cosine_sim(nano_k, torch_k)
|
||||
|
||||
nano_v = nanovllm_kv_cpu[layer_idx]["v"][:INPUT_LEN].cuda()
|
||||
torch_v = torch_qkv_outputs[layer_idx]["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
|
||||
v_sim = cosine_sim(nano_v, torch_v)
|
||||
|
||||
# O: compare attention outputs directly
|
||||
# For single-chunk case (input_len <= block_size), shapes should match
|
||||
# For multi-chunk case, nano_out is the last chunk only
|
||||
nano_out = nanovllm_outputs[layer_idx]
|
||||
torch_out = torch_outputs[layer_idx]
|
||||
|
||||
if nano_out.numel() == torch_out.numel():
|
||||
# Single chunk or shapes match - compare directly
|
||||
o_sim = cosine_sim(nano_out, torch_out)
|
||||
else:
|
||||
# Multi-chunk case: compare last chunk with corresponding torch slice
|
||||
last_chunk_len = nano_out.shape[1] if nano_out.dim() == 3 else nano_out.shape[0]
|
||||
torch_out_slice = torch_out[:, -last_chunk_len:, :] if torch_out.dim() == 3 else torch_out[-last_chunk_len:, :]
|
||||
o_sim = cosine_sim(nano_out, torch_out_slice)
|
||||
else:
|
||||
# ============================================================
|
||||
# Non-offload mode: original logic
|
||||
# ============================================================
|
||||
# Input similarity
|
||||
nano_in = nanovllm_proj_inputs[layer_idx]
|
||||
torch_in = torch_proj_inputs[layer_idx]
|
||||
|
||||
if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
|
||||
torch_in = torch_in.view(nano_in.shape)
|
||||
|
||||
i_sim = cosine_sim(nano_in, torch_in)
|
||||
|
||||
# QKV similarities
|
||||
q_sim, k_sim, v_sim = compute_qkv_sims(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups)
|
||||
|
||||
# O similarity
|
||||
nano_out = nanovllm_outputs[layer_idx]
|
||||
torch_out = torch_outputs[layer_idx]
|
||||
if nano_out.shape != torch_out.shape and nano_out.numel() == torch_out.numel():
|
||||
torch_out = torch_out.view(nano_out.shape)
|
||||
o_sim = cosine_sim(nano_out, torch_out)
|
||||
|
||||
# Check pass/fail
|
||||
passed = all(s >= threshold for s in [i_sim, q_sim, k_sim, v_sim, o_sim])
|
||||
all_passed = all_passed and passed
|
||||
status = "" if passed else " *"
|
||||
|
||||
print(f"Layer {layer_idx:2d}{status:<3} {i_sim:>10.6f} {q_sim:>10.6f} {k_sim:>10.6f} {v_sim:>10.6f} {o_sim:>10.6f}")
|
||||
|
||||
# ============================================================
|
||||
# Cleanup and result
|
||||
# ============================================================
|
||||
for hook in nanovllm_hooks + torch_hooks:
|
||||
hook.remove()
|
||||
|
||||
print("=" * 70)
|
||||
mode_str = " [offload]" if ENABLE_OFFLOAD else ""
|
||||
if all_passed:
|
||||
print(f"test_align{mode_str}: PASSED (cosine_sim >= 0.999)")
|
||||
else:
|
||||
print(f"test_align{mode_str}: FAILED (* = cosine_sim < 0.999)")
|
||||
134
tests/test_nanovllm_steppable.py
Normal file
134
tests/test_nanovllm_steppable.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
Test NanovllmSteppable: Print activation statistics at each layer.
|
||||
|
||||
Usage:
|
||||
python tests/test_nanovllm_steppable.py
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from nanovllm import LLM
|
||||
from nanovllm.debug.adapters.nanovllm_adapter import NanovllmSteppable
|
||||
from utils import generate_needle_prompt, check_needle_answer
|
||||
|
||||
# ============================================================
|
||||
# Config
|
||||
# ============================================================
|
||||
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
|
||||
INPUT_LEN = 32768 # Longer context to test offload
|
||||
MAX_NEW_TOKENS = 20
|
||||
DTYPE = torch.float16
|
||||
ENABLE_CPU_OFFLOAD = True # Test offload mode
|
||||
|
||||
# ============================================================
|
||||
# Load Model
|
||||
# ============================================================
|
||||
print(f"Loading nanovllm model (cpu_offload={ENABLE_CPU_OFFLOAD})...")
|
||||
llm = LLM(
|
||||
MODEL_PATH,
|
||||
enforce_eager=True, # Required for hooks to work
|
||||
max_model_len=40960,
|
||||
max_num_batched_tokens=40960,
|
||||
enable_cpu_offload=ENABLE_CPU_OFFLOAD,
|
||||
dtype="float16",
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
|
||||
|
||||
# Get the underlying model for steppable
|
||||
model = llm.model_runner.model
|
||||
|
||||
# ============================================================
|
||||
# Prepare Input (using needle-in-haystack prompt)
|
||||
# ============================================================
|
||||
prompt, expected_answer = generate_needle_prompt(
|
||||
tokenizer,
|
||||
target_length=INPUT_LEN,
|
||||
needle_position=0.5,
|
||||
needle_value="7492",
|
||||
use_chat_template=False,
|
||||
verbose=True,
|
||||
)
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
|
||||
print(f"Input shape: {input_ids.shape}")
|
||||
print(f"Expected answer: {expected_answer}\n")
|
||||
|
||||
# ============================================================
|
||||
# Create Steppable Model (reused for prefill + decode)
|
||||
# ============================================================
|
||||
steppable = NanovllmSteppable(model)
|
||||
|
||||
# ============================================================
|
||||
# Prefill Phase: Print activation stats
|
||||
# ============================================================
|
||||
print("=" * 85)
|
||||
print("PREFILL PHASE")
|
||||
print("=" * 85)
|
||||
print(f"{'Layer':<15} {'Shape':<25} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
|
||||
print("-" * 85)
|
||||
|
||||
current_ids = input_ids.clone()
|
||||
logits = None
|
||||
|
||||
for bp in steppable.step(current_ids, is_prefill=True):
|
||||
t = bp.tensor.float()
|
||||
shape_str = str(list(t.shape))
|
||||
print(f"{bp.name:<15} {shape_str:<25} {t.mean():>10.4f} {t.std():>10.4f} {t.min():>10.4f} {t.max():>10.4f}")
|
||||
if bp.name == "LM Head":
|
||||
logits = bp.tensor
|
||||
|
||||
# Get first token from prefill
|
||||
next_token_id = logits[0, -1].argmax().item()
|
||||
next_token = tokenizer.decode(next_token_id)
|
||||
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
|
||||
generated_tokens = [next_token]
|
||||
|
||||
# ============================================================
|
||||
# Decode Phase: Only print generated tokens
|
||||
# ============================================================
|
||||
print("\n" + "=" * 85)
|
||||
print("DECODE PHASE")
|
||||
print("=" * 85)
|
||||
print(f"Step 1: {next_token!r}")
|
||||
|
||||
for step in range(2, MAX_NEW_TOKENS + 1):
|
||||
# Forward pass with full sequence (reuse same steppable)
|
||||
# Note: nanovllm without KV cache needs full sequence for each decode
|
||||
for bp in steppable.step(current_ids, is_prefill=True):
|
||||
if bp.name == "LM Head":
|
||||
logits = bp.tensor
|
||||
|
||||
# Get next token (greedy)
|
||||
next_token_id = logits[0, -1].argmax().item()
|
||||
next_token = tokenizer.decode(next_token_id)
|
||||
generated_tokens.append(next_token)
|
||||
|
||||
print(f"Step {step:2}: {next_token!r}")
|
||||
|
||||
# Stop if EOS
|
||||
if next_token_id == tokenizer.eos_token_id:
|
||||
print(" (EOS)")
|
||||
break
|
||||
|
||||
# Append to sequence
|
||||
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
|
||||
|
||||
# ============================================================
|
||||
# Result
|
||||
# ============================================================
|
||||
print("\n" + "=" * 85)
|
||||
print("RESULT")
|
||||
print("=" * 85)
|
||||
generated_text = "".join(generated_tokens)
|
||||
print(f"Generated: {generated_text!r}")
|
||||
print(f"Expected: {expected_answer}")
|
||||
print(f"Answer: {'CORRECT!' if check_needle_answer(generated_text, expected_answer) else 'INCORRECT'}")
|
||||
|
||||
print("\ntest_nanovllm_steppable: PASSED")
|
||||
@@ -8,155 +8,11 @@ sequences longer than ~200 tokens. Use --no-offload for correctness testing.
|
||||
"""
|
||||
|
||||
import os
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
|
||||
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
||||
|
||||
import argparse
|
||||
from nanovllm import LLM, SamplingParams
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Needle Test Generator
|
||||
# ============================================================
|
||||
|
||||
def generate_needle_prompt(
|
||||
tokenizer,
|
||||
target_length: int,
|
||||
needle_position: float = 0.5,
|
||||
needle_value: str = "7492",
|
||||
use_chat_template: bool = True,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Generate a needle-in-haystack prompt of approximately target_length tokens.
|
||||
|
||||
Args:
|
||||
tokenizer: HuggingFace tokenizer for length estimation
|
||||
target_length: Target total sequence length in tokens
|
||||
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
|
||||
needle_value: The secret value to hide in the haystack
|
||||
use_chat_template: Whether to use chat template for instruct models
|
||||
|
||||
Returns:
|
||||
(prompt, expected_answer): The full prompt and the expected needle value
|
||||
"""
|
||||
# Haystack filler paragraphs (various topics to create realistic context)
|
||||
haystack_paragraphs = [
|
||||
"The weather today is quite pleasant with clear skies and moderate temperatures. "
|
||||
"Many people are enjoying outdoor activities in the park. "
|
||||
"Birds are singing in the trees and children are playing on the swings. ",
|
||||
|
||||
"In the world of technology, new innovations continue to emerge every day. "
|
||||
"Researchers are working on advanced algorithms and computing systems. "
|
||||
"The future of artificial intelligence looks promising with many breakthroughs. ",
|
||||
|
||||
"The history of human civilization spans thousands of years. "
|
||||
"Ancient cultures developed writing, mathematics, and astronomy. "
|
||||
"Trade routes connected distant lands and facilitated cultural exchange. ",
|
||||
|
||||
"Modern cooking combines traditional techniques with new ingredients. "
|
||||
"Chefs around the world experiment with flavors and presentations. "
|
||||
"Food brings people together and creates memorable experiences. ",
|
||||
|
||||
"The ocean covers more than seventy percent of Earth's surface. "
|
||||
"Marine ecosystems support an incredible diversity of life forms. "
|
||||
"Scientists continue to discover new species in the deep sea. ",
|
||||
|
||||
"Music has been a part of human culture since prehistoric times. "
|
||||
"Different genres evolved across various regions and time periods. "
|
||||
"Today, people can access millions of songs through digital platforms. ",
|
||||
|
||||
"Space exploration has revealed many secrets about our universe. "
|
||||
"Telescopes can observe galaxies billions of light years away. "
|
||||
"Future missions aim to establish human presence on other planets. ",
|
||||
|
||||
"The study of languages reveals patterns in human cognition. "
|
||||
"Linguists analyze grammar, semantics, and phonetics across cultures. "
|
||||
"Language continues to evolve with new words and expressions. ",
|
||||
]
|
||||
|
||||
# The needle sentence
|
||||
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
|
||||
|
||||
# Question at the end
|
||||
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
||||
|
||||
# Estimate tokens for fixed parts
|
||||
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
|
||||
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
|
||||
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
|
||||
# Buffer for chat template, special tokens, etc.
|
||||
overhead_tokens = 100 if use_chat_template else 50
|
||||
|
||||
# Available tokens for haystack
|
||||
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
|
||||
if haystack_target_tokens < 100:
|
||||
raise ValueError(f"target_length {target_length} is too short for needle test")
|
||||
|
||||
# Build haystack by repeating paragraphs
|
||||
haystack_parts = []
|
||||
current_tokens = 0
|
||||
para_idx = 0
|
||||
|
||||
while current_tokens < haystack_target_tokens:
|
||||
para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
|
||||
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
|
||||
if current_tokens + para_tokens > haystack_target_tokens:
|
||||
break
|
||||
haystack_parts.append(para)
|
||||
current_tokens += para_tokens
|
||||
para_idx += 1
|
||||
|
||||
# Calculate needle insertion point
|
||||
needle_idx = int(len(haystack_parts) * needle_position)
|
||||
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
|
||||
|
||||
# Insert needle
|
||||
haystack_parts.insert(needle_idx, needle)
|
||||
|
||||
# Assemble prompt
|
||||
full_text = "".join(haystack_parts)
|
||||
|
||||
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
|
||||
# Use chat template for instruct models
|
||||
# For Qwen3, add /no_think to disable thinking mode
|
||||
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
|
||||
messages = [
|
||||
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
|
||||
]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
else:
|
||||
# Raw text format for base models
|
||||
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
||||
prompt = full_text + question
|
||||
|
||||
# Verify length
|
||||
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
|
||||
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
|
||||
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
|
||||
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
|
||||
|
||||
return prompt, needle_value
|
||||
|
||||
|
||||
def check_needle_answer(output_text: str, expected: str) -> bool:
|
||||
"""Check if the model output contains the expected needle value."""
|
||||
import re
|
||||
# Clean output - remove special tokens and whitespace
|
||||
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
|
||||
output_clean = ' '.join(output_clean.split()).lower()
|
||||
expected_clean = expected.strip().lower()
|
||||
|
||||
# Check if expected value appears in output
|
||||
# Also try to find it as a standalone number
|
||||
if expected_clean in output_clean:
|
||||
return True
|
||||
|
||||
# Try to extract numbers and check if expected is among them
|
||||
numbers = re.findall(r'\d+', output_clean)
|
||||
return expected_clean in numbers
|
||||
from utils import generate_needle_prompt, check_needle_answer
|
||||
|
||||
|
||||
# ============================================================
|
||||
@@ -168,6 +24,7 @@ def run_needle_test(
|
||||
max_model_len: int,
|
||||
input_len: int,
|
||||
num_gpu_blocks: int = 4,
|
||||
block_size: int = 1024,
|
||||
needle_position: float = 0.5,
|
||||
needle_value: str = "7492",
|
||||
max_new_tokens: int = 32,
|
||||
@@ -182,6 +39,7 @@ def run_needle_test(
|
||||
max_model_len: Maximum model context length
|
||||
input_len: Target input sequence length
|
||||
num_gpu_blocks: Number of GPU blocks for offload
|
||||
block_size: KV cache block size
|
||||
needle_position: Where to place needle (0.0-1.0)
|
||||
needle_value: The secret value to find
|
||||
max_new_tokens: Maximum tokens to generate
|
||||
@@ -198,6 +56,7 @@ def run_needle_test(
|
||||
print(f"Model: {model_path}")
|
||||
print(f"Max model len: {max_model_len}")
|
||||
print(f"Input length: {input_len}")
|
||||
print(f"Block size: {block_size}")
|
||||
print(f"Needle position: {needle_position:.0%}")
|
||||
print(f"Needle value: {needle_value}")
|
||||
print(f"CPU offload: {enable_cpu_offload}")
|
||||
@@ -209,6 +68,7 @@ def run_needle_test(
|
||||
"max_model_len": max_model_len,
|
||||
"max_num_batched_tokens": max_model_len,
|
||||
"enable_cpu_offload": enable_cpu_offload,
|
||||
"kvcache_block_size": block_size,
|
||||
}
|
||||
if enable_cpu_offload:
|
||||
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
||||
@@ -263,7 +123,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--max-model-len",
|
||||
type=int,
|
||||
default=32 * 1024,
|
||||
default=36 * 1024,
|
||||
help="Maximum model context length"
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -278,6 +138,12 @@ if __name__ == "__main__":
|
||||
default=2,
|
||||
help="Number of GPU blocks for CPU offload"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--block-size",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="KV cache block size"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--needle-position",
|
||||
type=float,
|
||||
@@ -308,6 +174,7 @@ if __name__ == "__main__":
|
||||
max_model_len=args.max_model_len,
|
||||
input_len=args.input_len,
|
||||
num_gpu_blocks=args.num_gpu_blocks,
|
||||
block_size=args.block_size,
|
||||
needle_position=args.needle_position,
|
||||
needle_value=args.needle_value,
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
|
||||
@@ -8,148 +8,9 @@ Uses standard HuggingFace inference (no custom KV cache, no offload).
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Needle Test Generator
|
||||
# ============================================================
|
||||
|
||||
def generate_needle_prompt(
|
||||
tokenizer,
|
||||
target_length: int,
|
||||
needle_position: float = 0.5,
|
||||
needle_value: str = "7492",
|
||||
use_chat_template: bool = True,
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
    Generate a needle-in-haystack prompt of approximately target_length tokens.

    Args:
        tokenizer: HuggingFace tokenizer for length estimation
        target_length: Target total sequence length in tokens
        needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
        needle_value: The secret value to hide in the haystack
        use_chat_template: Whether to use chat template for instruct models

    Returns:
        (prompt, expected_answer): The full prompt and the expected needle value
    """
    # Haystack filler paragraphs (various topics to create realistic context)
    haystack_paragraphs = [
        "The weather today is quite pleasant with clear skies and moderate temperatures. "
        "Many people are enjoying outdoor activities in the park. "
        "Birds are singing in the trees and children are playing on the swings. ",

        "In the world of technology, new innovations continue to emerge every day. "
        "Researchers are working on advanced algorithms and computing systems. "
        "The future of artificial intelligence looks promising with many breakthroughs. ",

        "The history of human civilization spans thousands of years. "
        "Ancient cultures developed writing, mathematics, and astronomy. "
        "Trade routes connected distant lands and facilitated cultural exchange. ",

        "Modern cooking combines traditional techniques with new ingredients. "
        "Chefs around the world experiment with flavors and presentations. "
        "Food brings people together and creates memorable experiences. ",

        "The ocean covers more than seventy percent of Earth's surface. "
        "Marine ecosystems support an incredible diversity of life forms. "
        "Scientists continue to discover new species in the deep sea. ",

        "Music has been a part of human culture since prehistoric times. "
        "Different genres evolved across various regions and time periods. "
        "Today, people can access millions of songs through digital platforms. ",

        "Space exploration has revealed many secrets about our universe. "
        "Telescopes can observe galaxies billions of light years away. "
        "Future missions aim to establish human presence on other planets. ",

        "The study of languages reveals patterns in human cognition. "
        "Linguists analyze grammar, semantics, and phonetics across cultures. "
        "Language continues to evolve with new words and expressions. ",
    ]

    # The needle sentence
    needle = f"The secret number you need to remember is {needle_value}. This is very important. "

    # Estimate tokens for fixed parts
    needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
    question_text = "What is the secret number mentioned in the text above? Answer with just the number."
    question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
    # Buffer for chat template, special tokens, etc.
    overhead_tokens = 100 if use_chat_template else 50

    # Available tokens for haystack
    haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
    if haystack_target_tokens < 100:
        raise ValueError(f"target_length {target_length} is too short for needle test")

    # Build haystack by repeating paragraphs
    haystack_parts = []
    current_tokens = 0
    para_idx = 0

    while current_tokens < haystack_target_tokens:
        para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
        para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
        if current_tokens + para_tokens > haystack_target_tokens:
            break
        haystack_parts.append(para)
        current_tokens += para_tokens
        para_idx += 1

    # Calculate needle insertion point
    needle_idx = int(len(haystack_parts) * needle_position)
    needle_idx = max(0, min(needle_idx, len(haystack_parts)))

    # Insert needle
    haystack_parts.insert(needle_idx, needle)

    # Assemble prompt
    full_text = "".join(haystack_parts)

    if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
        # Use chat template for instruct models
        # For Qwen3, add /no_think to disable thinking mode
        question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
        messages = [
            {"role": "user", "content": f"{full_text}\n\n{question_text}"}
        ]
        prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
    else:
        # Raw text format for base models
        question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
        prompt = full_text + question

    # Verify length
    actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
    print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
    print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
    print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")

    return prompt, needle_value


def check_needle_answer(output_text: str, expected: str) -> bool:
    """Check if the model output contains the expected needle value."""
    import re
    # Clean output - remove special tokens and whitespace
    output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
    output_clean = ' '.join(output_clean.split()).lower()
    expected_clean = expected.strip().lower()

    # Check if expected value appears in output
    if expected_clean in output_clean:
        return True

    # Try to extract numbers and check if expected is among them
    numbers = re.findall(r'\d+', output_clean)
    return expected_clean in numbers

from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from utils import generate_needle_prompt, check_needle_answer


# ============================================================

@@ -207,22 +68,19 @@ def run_needle_test(
     # 3. Load model
     print("[3/4] Loading model...")
     torch_dtype = {
-        "auto": "auto",
+        "auto": torch.float16,  # default to float16 for custom model
         "float16": torch.float16,
         "bfloat16": torch.bfloat16,
-    }.get(dtype, "auto")
+    }.get(dtype, torch.float16)

-    model = AutoModelForCausalLM.from_pretrained(
-        model_path,
-        torch_dtype=torch_dtype,
-        device_map="auto",
-        trust_remote_code=True,
-    )
+    model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch_dtype)
+    model = model.to("cuda" if torch.cuda.is_available() else "cpu")
     model.eval()

     # 4. Generate output
     print("[4/4] Running inference...")
-    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+    device = next(model.parameters()).device
+    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
     print(f" Input shape: {input_ids.shape}")

     with torch.no_grad():
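
One detail in this hunk deserves a note: the reference script stops using `model.device` and derives the device from the parameters instead, presumably because the custom `Qwen3ForCausalLM` is a plain `nn.Module` without the Hugging Face `.device` convenience attribute. A minimal, generic illustration of that pattern (the `nn.Linear` stand-in is hypothetical, not the repo's model):

```python
import torch
import torch.nn as nn

# Stand-in module; assumes only that the model has parameters,
# not that it exposes a `.device` attribute.
model = nn.Linear(8, 8).to("cuda" if torch.cuda.is_available() else "cpu")

device = next(model.parameters()).device  # device of the first parameter
x = torch.randn(1, 8, device=device)      # move inputs to the same device
print(model(x).shape, device)
```
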
121
tests/test_torch_steppable.py
Normal file
@@ -0,0 +1,121 @@
"""
Test TorchSteppable: Print activation statistics at each layer.

Usage:
    python tests/test_torch_steppable.py
"""

import os
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent))

import torch
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from nanovllm.debug.adapters.torch_adapter import TorchSteppable
from utils import generate_needle_prompt, check_needle_answer

# ============================================================
# Config
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
INPUT_LEN = 512
MAX_NEW_TOKENS = 20
DTYPE = torch.float16

# ============================================================
# Load Model
# ============================================================
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
model = model.to("cuda").eval()

# ============================================================
# Prepare Input (using needle-in-haystack prompt)
# ============================================================
prompt, expected_answer = generate_needle_prompt(
    tokenizer,
    target_length=INPUT_LEN,
    needle_position=0.5,
    needle_value="7492",
    use_chat_template=False,
    verbose=True,
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
print(f"Input shape: {input_ids.shape}")
print(f"Expected answer: {expected_answer}\n")

# ============================================================
# Create Steppable Model (reused for prefill + decode)
# ============================================================
steppable = TorchSteppable(model)

# ============================================================
# Prefill Phase: Print activation stats
# ============================================================
print("=" * 85)
print("PREFILL PHASE")
print("=" * 85)
print(f"{'Layer':<15} {'Shape':<25} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
print("-" * 85)

current_ids = input_ids.clone()
logits = None

for bp in steppable.step(current_ids):
    t = bp.tensor.float()
    shape_str = str(list(t.shape))
    print(f"{bp.name:<15} {shape_str:<25} {t.mean():>10.4f} {t.std():>10.4f} {t.min():>10.4f} {t.max():>10.4f}")
    if bp.name == "LM Head":
        logits = bp.tensor

# Get first token from prefill
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
generated_tokens = [next_token]

# ============================================================
# Decode Phase: Only print generated tokens
# ============================================================
print("\n" + "=" * 85)
print("DECODE PHASE")
print("=" * 85)
print(f"Step 1: {next_token!r}")

for step in range(2, MAX_NEW_TOKENS + 1):
    # Forward pass (reuse same steppable)
    for bp in steppable.step(current_ids):
        if bp.name == "LM Head":
            logits = bp.tensor

    # Get next token (greedy)
    next_token_id = logits[0, -1].argmax().item()
    next_token = tokenizer.decode(next_token_id)
    generated_tokens.append(next_token)

    print(f"Step {step:2}: {next_token!r}")

    # Stop if EOS
    if next_token_id == tokenizer.eos_token_id:
        print(" (EOS)")
        break

    # Append to sequence
    current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)

# ============================================================
# Result
# ============================================================
print("\n" + "=" * 85)
print("RESULT")
print("=" * 85)
generated_text = "".join(generated_tokens)
print(f"Generated: {generated_text!r}")
print(f"Expected: {expected_answer}")
print(f"Answer: {'CORRECT!' if check_needle_answer(generated_text, expected_answer) else 'INCORRECT'}")

print("\ntest_torch_steppable: PASSED")
186
tests/utils.py
Normal file
@@ -0,0 +1,186 @@
"""
Test utilities for nano-vllm.
"""

import re
from typing import Tuple


# ============================================================
# Needle-in-Haystack Test Utilities
# ============================================================

# Haystack filler paragraphs (various topics to create realistic context)
HAYSTACK_PARAGRAPHS = [
    "The weather today is quite pleasant with clear skies and moderate temperatures. "
    "Many people are enjoying outdoor activities in the park. "
    "Birds are singing in the trees and children are playing on the swings. ",

    "In the world of technology, new innovations continue to emerge every day. "
    "Researchers are working on advanced algorithms and computing systems. "
    "The future of artificial intelligence looks promising with many breakthroughs. ",

    "The history of human civilization spans thousands of years. "
    "Ancient cultures developed writing, mathematics, and astronomy. "
    "Trade routes connected distant lands and facilitated cultural exchange. ",

    "Modern cooking combines traditional techniques with new ingredients. "
    "Chefs around the world experiment with flavors and presentations. "
    "Food brings people together and creates memorable experiences. ",

    "The ocean covers more than seventy percent of Earth's surface. "
    "Marine ecosystems support an incredible diversity of life forms. "
    "Scientists continue to discover new species in the deep sea. ",

    "Music has been a part of human culture since prehistoric times. "
    "Different genres evolved across various regions and time periods. "
    "Today, people can access millions of songs through digital platforms. ",

    "Space exploration has revealed many secrets about our universe. "
    "Telescopes can observe galaxies billions of light years away. "
    "Future missions aim to establish human presence on other planets. ",

    "The study of languages reveals patterns in human cognition. "
    "Linguists analyze grammar, semantics, and phonetics across cultures. "
    "Language continues to evolve with new words and expressions. ",
]


def generate_needle_prompt(
    tokenizer,
    target_length: int,
    needle_position: float = 0.5,
    needle_value: str = "7492",
    use_chat_template: bool = True,
    verbose: bool = True,
) -> Tuple[str, str]:
    """
    Generate a needle-in-haystack prompt as close as possible to target_length tokens.

    Args:
        tokenizer: HuggingFace tokenizer for length estimation
        target_length: Target total sequence length in tokens
        needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
        needle_value: The secret value to hide in the haystack
        use_chat_template: Whether to use chat template for instruct models
        verbose: Whether to print generation info

    Returns:
        (prompt, expected_answer): The full prompt and the expected needle value
    """
    # The needle sentence
    needle = f"The secret number you need to remember is {needle_value}. This is very important. "

    # Question text
    if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
        question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
    else:
        question_text = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"

    def build_prompt(haystack_parts, needle_idx):
        """Build full prompt from haystack parts with needle inserted."""
        parts = haystack_parts.copy()
        parts.insert(needle_idx, needle)
        full_text = "".join(parts)

        if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
            messages = [{"role": "user", "content": f"{full_text}\n\n{question_text}"}]
            return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        else:
            return full_text + question_text

    def count_tokens(prompt):
        return len(tokenizer.encode(prompt, add_special_tokens=False))

    def get_needle_idx(parts):
        idx = int(len(parts) * needle_position)
        return max(0, min(idx, len(parts)))

    # Phase 1: Build haystack with full paragraphs until we exceed target
    haystack_parts = []
    para_idx = 0

    while True:
        para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
        test_parts = haystack_parts + [para]
        prompt = build_prompt(test_parts, get_needle_idx(test_parts))

        if count_tokens(prompt) > target_length:
            break

        haystack_parts.append(para)
        para_idx += 1

        if para_idx > 10000:  # Safety limit
            break

    # Phase 2: Fine-tune by adding words from the next paragraph
    next_para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
    words = next_para.split()

    best_parts = haystack_parts.copy()
    best_diff = abs(target_length - count_tokens(build_prompt(haystack_parts, get_needle_idx(haystack_parts))))

    for i in range(1, len(words) + 1):
        partial = " ".join(words[:i]) + " "
        test_parts = haystack_parts + [partial]
        prompt = build_prompt(test_parts, get_needle_idx(test_parts))
        token_count = count_tokens(prompt)
        diff = abs(target_length - token_count)

        if diff < best_diff:
            best_diff = diff
            best_parts = test_parts.copy()

        if token_count >= target_length:
            break

    haystack_parts = best_parts

    # Final build
    needle_idx = get_needle_idx(haystack_parts)
    prompt = build_prompt(haystack_parts, needle_idx)

    actual_tokens = count_tokens(prompt)
    if verbose:
        print(f"[NeedleTest] Target: {target_length}, Actual: {actual_tokens} tokens (diff={actual_tokens - target_length})")

    return prompt, needle_value
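
A quick example of calling the helper above; the checkpoint name and argument values here are placeholders, not taken from the diff:

```python
from transformers import AutoTokenizer

from utils import generate_needle_prompt  # tests/utils.py

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")  # placeholder checkpoint
prompt, expected = generate_needle_prompt(
    tokenizer,
    target_length=2048,
    needle_position=0.25,   # needle a quarter of the way into the haystack
    needle_value="31337",
    use_chat_template=True,
)
print(len(tokenizer.encode(prompt, add_special_tokens=False)), expected)
```
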

def check_needle_answer(output_text: str, expected: str) -> bool:
    """Check if the model output contains the expected needle value."""
    # Clean output - remove special tokens and whitespace
    output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
    output_clean = ' '.join(output_clean.split()).lower()
    expected_clean = expected.strip().lower()

    # Check if expected value appears in output
    # Also try to find it as a standalone number
    if expected_clean in output_clean:
        return True

    # Try to extract numbers and check if expected is among them
    numbers = re.findall(r'\d+', output_clean)
    return expected_clean in numbers
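
A few sanity checks showing how the matching behaves (substring match first, then a digit-level fallback). These asserts are illustrative only and are not part of the test suite:

```python
from utils import check_needle_answer

assert check_needle_answer("The secret number is 7492.", "7492")      # plain substring hit
assert check_needle_answer("7492<|im_end|>", "7492")                  # special tokens are stripped
assert not check_needle_answer("The secret number is 1234.", "7492")  # wrong number is rejected
```
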

def generate_random_token_ids(
    length: int,
    vocab_size: int = 10000,
    seed: int = 42,
) -> list:
    """
    Generate random token IDs for testing.

    Args:
        length: Number of tokens to generate
        vocab_size: Maximum token ID (exclusive)
        seed: Random seed for reproducibility

    Returns:
        List of random token IDs
    """
    from random import randint, seed as set_seed
    set_seed(seed)
    return [randint(0, vocab_size - 1) for _ in range(length)]
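
And a minimal use of the random-ID helper, for example when building synthetic prefill inputs; the determinism check below is just an illustration:

```python
from utils import generate_random_token_ids

ids_a = generate_random_token_ids(length=8, vocab_size=100, seed=7)
ids_b = generate_random_token_ids(length=8, vocab_size=100, seed=7)
assert ids_a == ids_b                     # same seed -> identical sequence
assert all(0 <= t < 100 for t in ids_a)   # IDs stay below vocab_size
print(ids_a)
```
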