diff --git a/CLAUDE.md b/CLAUDE.md
index bdbb4d1..5b37e28 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -20,6 +20,80 @@ For sparse attention related content (block sparse attention, MInference, FlexPr
 - **BlockManager** (`block_manager.py`): Paged attention with prefix caching (xxhash), default block size 4096
 - **Attention** (`layers/attention.py`): FlashAttention with chunked methods for CPU offload
 
+## PyTorch Hooks for Debugging
+
+### Hook Positions in Qwen3
+
+```
+decoder_layer
+├── input_layernorm (RMSNorm)
+├── self_attn (Qwen3Attention)   ← Hook here for attention I/O (output is after o_proj)
+│   ├── q_proj → q_norm → RoPE
+│   ├── k_proj → k_norm → RoPE
+│   ├── v_proj
+│   ├── attn (Attention)         ← Hook here for Q/K/V tensors
+│   │   └── FlashAttention / SDPA
+│   └── o_proj
+├── post_attention_layernorm (RMSNorm)
+└── mlp (Qwen3MLP)
+```
+
+### Hook Types & Data Shapes
+
+| Hook Position | Type | Captured Data |
+|---------------|------|---------------|
+| `self_attn` | post | `[num_tokens, hidden_size]` in nanovllm, `[batch, seq_len, hidden_size]` in the torch reference - after o_proj |
+| `self_attn.attn` | pre | Q, K, V: `[num_tokens, num_heads, head_dim]` - after RoPE |
+| `self_attn.attn` | post | `[num_tokens, num_heads, head_dim]` - before o_proj |
+
+### Example: Capture Attention Outputs
+
+```python
+storage = {}
+
+def make_hook(layer_id: int, storage: dict):
+    def hook(module, inputs, output):
+        if isinstance(output, tuple):
+            attn_output = output[0]
+        else:
+            attn_output = output
+        # nanovllm shape: [num_tokens, hidden_size] -> add a batch dim
+        if attn_output.dim() == 2:
+            attn_output = attn_output.unsqueeze(0)
+        storage[layer_id] = attn_output.detach().clone()
+    return hook
+
+# Register a post-forward hook on every decoder layer's self_attn
+hooks = []
+for layer_idx, layer in enumerate(model.model.layers):
+    hooks.append(layer.self_attn.register_forward_hook(make_hook(layer_idx, storage)))
+
+# Run inference...
+
+# Cleanup: remove the hooks so they do not fire on later runs
+for hook in hooks:
+    hook.remove()
+```
+
+### Alignment Testing
+
+Use `tests/test_align.py` to compare nanovllm against the reference torch implementation:
+
+```bash
+python tests/test_align.py
+```
+
+Key files:
+- `tests/modeling_qwen3.py`: Reference Qwen3 implementation (torch + transformers only)
+- `tests/test_align.py`: Compares attention outputs between nanovllm and the reference
+- `tests/test_needle_ref.py`: Reference needle test using the custom Qwen3
+
+### Common Pitfalls
+
+1. **Shape mismatch**: nanovllm uses `[num_tokens, ...]` while the torch reference uses `[batch, seq_len, ...]`
+2. **Hook position**: a hook on `self_attn` captures output after o_proj; a hook on `self_attn.attn` captures output before o_proj
+3. **Output format**: nanovllm returns the tuple `(attn_output, None)`; unwrap it with `output[0]`
+
 ## CPU Offload System
 
 ### Ring Buffer Design
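
A few supplementary sketches for the hooks section in the patch above. First, the shapes table lists a *pre* hook on `self_attn.attn` for grabbing Q/K/V after RoPE, but only the post-hook path has an example. The sketch below shows that variant, assuming the same `model` object as in the patch and that `Attention.forward` receives `(q, k, v)` as its leading positional arguments, as the table implies; `make_qkv_hook` and `qkv_storage` are illustrative names, not existing helpers.

```python
qkv_storage = {}

def make_qkv_hook(layer_id: int, storage: dict):
    def pre_hook(module, inputs):
        # Forward pre-hook: `inputs` is the tuple of positional args passed to
        # Attention.forward -- assumed here to start with (q, k, v) after RoPE,
        # each [num_tokens, num_heads, head_dim] in nanovllm.
        q, k, v = inputs[:3]
        storage[layer_id] = tuple(t.detach().clone() for t in (q, k, v))
    return pre_hook

hooks = []
for layer_idx, layer in enumerate(model.model.layers):
    hooks.append(
        layer.self_attn.attn.register_forward_pre_hook(make_qkv_hook(layer_idx, qkv_storage))
    )

# Run inference, then clean up exactly as in the post-hook example:
for hook in hooks:
    hook.remove()
```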
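
For alignment testing, the real comparison logic lives in `tests/test_align.py`; the sketch below only illustrates the general pattern such a test can follow once per-layer outputs have been captured from both sides: check shapes, then report `torch.allclose` plus the max absolute difference per layer. `nanovllm_outputs` and `reference_outputs` are hypothetical dicts standing in for whatever the hooks collect, and the tolerances are placeholders rather than the values the test actually uses.

```python
import torch

def compare_layer_outputs(nanovllm_outputs: dict, reference_outputs: dict,
                          atol: float = 1e-3, rtol: float = 1e-3) -> None:
    """Print a per-layer report of how closely two sets of captures agree."""
    for layer_id in sorted(nanovllm_outputs):
        a = nanovllm_outputs[layer_id].float()
        b = reference_outputs[layer_id].float()
        assert a.shape == b.shape, f"layer {layer_id}: {a.shape} vs {b.shape}"
        max_diff = (a - b).abs().max().item()
        ok = torch.allclose(a, b, atol=atol, rtol=rtol)
        print(f"layer {layer_id:2d}: max_abs_diff={max_diff:.3e} allclose={ok}")
```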
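
Pitfalls 1 and 3 tend to appear together whenever a nanovllm capture is compared against one from the reference model. A tiny helper like the sketch below (not an existing utility) folds both fixes into one place, so comparison code never sees a raw tuple or a missing batch dimension.

```python
import torch

def normalize_capture(output) -> torch.Tensor:
    # Pitfall 3: nanovllm's attention may return (attn_output, None) -- unwrap it.
    if isinstance(output, tuple):
        output = output[0]
    # Pitfall 1: nanovllm tensors are [num_tokens, hidden_size]; add the batch
    # dim so they line up with the reference's [batch, seq_len, hidden_size].
    if output.dim() == 2:
        output = output.unsqueeze(0)
    return output.detach().clone()
```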