Compare commits

29 commits, ff8b09cd35...tzj/minfer:

6575099a06, 8fd25d72d7, ccf27d3a74, 0ad86eb449, aa953ecb59, 362f5e575f,
58a06501c1, 2a6e0a2c02, 2fe50bab50, c99a6f3d3f, f240903013, 0e691f2d85,
edb5273e34, 690492e074, 7cc8a394a5, 535f2037ab, c7ac39dfbd, e554d5482b,
247c5312d9, 054aaff403, d623043a3c, e897380127, 24096431ed, 772313db8f,
00ed17c640, 9b52d25866, 8c3418725b, b3685c9190, 6927a75ac3
.gitignore (vendored) · 1 line changed

@@ -196,3 +196,4 @@ cython_debug/
 results/
 outputs/
+.local/
CLAUDE.md · 234 lines changed

@@ -6,10 +6,119 @@ This file provides guidance to Claude Code when working with this repository.

Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Supports Qwen3 models with CPU offload for long-context inference.

## GPU Mutex for Multi-Instance Debugging

**IMPORTANT**: When running multiple Claude instances for parallel debugging, only one GPU (cuda:0) is available. Before executing ANY command that uses the GPU (Python scripts, benchmarks, tests), Claude MUST:

1. **Check GPU availability** by running:

```bash
nvidia-smi --query-compute-apps=pid,name,used_memory --format=csv,noheader
```

2. **If processes are running on the GPU**:
- Wait and retry every 10 seconds until the GPU is free
- Use this polling loop:

```bash
while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do
    echo "GPU busy, waiting 10s..."
    sleep 10
done
```

3. **Only proceed** when `nvidia-smi --query-compute-apps=pid --format=csv,noheader` returns empty output.

**Example workflow**:

```bash
# First check if GPU is in use
nvidia-smi --query-compute-apps=pid,name,used_memory --format=csv,noheader

# If output is empty, proceed with your command
python bench_offload.py

# If output shows processes, wait until they finish
```

**Note**: This applies to ALL GPU operations, including:
- Running tests (`python tests/test_*.py`)
- Running benchmarks (`python bench*.py`)
- Running examples (`python example.py`)
- Any script that imports torch/cuda

## Local Package Installation for Multi-Instance

**CRITICAL**: After ANY code modification in the `nanovllm/` directory, you MUST reinstall the package before running tests or benchmarks; otherwise Python will keep using the old cached version and your changes will NOT be reflected:

```bash
pip install -e . --prefix=./.local --no-deps
```

Then run with PYTHONPATH:

```bash
PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python <script.py>
```

**IMPORTANT**: When running multiple Claude instances on different worktrees, do NOT use `pip install -e .` globally, as it will affect other instances. Instead, use a local installation:

1. **Install to a worktree-local directory**:

```bash
pip install -e . --prefix=./.local --no-deps
```

2. **Set PYTHONPATH before running any Python command**:

```bash
export PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH
```

3. **Combined example**:

```bash
# One-liner for running tests with local package
PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python tests/test_needle.py
```

**Note**: The Python version in the path (python3.10) should match your environment.
## Sparse Attention

For sparse attention related content (block sparse attention, MInference, FlexPrefill, XAttention, AvgPool, etc.), refer to [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md).

### Quest Sparse Policy

**Files**: `nanovllm/kvcache/sparse/quest.py`, `nanovllm/kvcache/sparse/policy.py`

The Quest policy selects Top-K blocks based on query-key similarity bounds computed from per-block min/max key metadata.

**Scoring Mechanism**:

```python
score_min = torch.einsum('hd,bhd->bh', q, key_min)          # [num_blocks, kv_heads]
score_max = torch.einsum('hd,bhd->bh', q, key_max)          # [num_blocks, kv_heads]
scores = torch.maximum(score_min, score_max).mean(dim=-1)   # [num_blocks] ← averaged!
```

**Critical Limitation - No Per-Head Scheduling**:

The `.mean(dim=-1)` averages scores across all heads, so one **unified** block selection is made for all heads (see the sketch after this subsection):

```
Block A: head0 needs (+4), head1 doesn't (-4)  → avg = 0  → NOT selected
Block B: head0 doesn't (-4), head1 needs (+4)  → avg = 0  → NOT selected
Block C: both heads moderately need (+2, +2)   → avg = +2 → selected
```

**Why Per-Head Scheduling is Infeasible**:
1. **Memory Layout**: the GPU cache stores all heads together as `[block_size, kv_heads, head_dim]`
2. **FlashAttention**: requires complete heads; partial heads cause dimension mismatches
3. **Block Granularity**: if any head needs a block, the entire block (all heads) must be loaded

**Policy Types**:
- `FullAttentionPolicy`: `supports_prefill=True, supports_decode=True` - loads all blocks
- `QuestPolicy`: `supports_prefill=False, supports_decode=True` - decode-only Top-K selection
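The following is a minimal, self-contained sketch (illustrative shapes and random data, not the actual `quest.py` code) that contrasts the unified, head-averaged Top-K selection Quest performs with the hypothetical per-head Top-K that the block memory layout cannot serve:

```python
# Illustrative sketch only - not the actual quest.py implementation.
import torch

num_blocks, kv_heads, head_dim, topk = 16, 4, 64, 4
q = torch.randn(kv_heads, head_dim)
key_min = torch.randn(num_blocks, kv_heads, head_dim)
key_max = torch.randn(num_blocks, kv_heads, head_dim)

# Per-block, per-head upper-bound scores (same einsum as the policy above)
score_min = torch.einsum('hd,bhd->bh', q, key_min)   # [num_blocks, kv_heads]
score_max = torch.einsum('hd,bhd->bh', q, key_max)   # [num_blocks, kv_heads]
per_head = torch.maximum(score_min, score_max)       # [num_blocks, kv_heads]

# Unified selection (what Quest does): average over heads, one block set for all heads
unified_topk = per_head.mean(dim=-1).topk(topk).indices        # [topk]

# Hypothetical per-head selection: a different block set per head.
# Infeasible here because blocks are stored and loaded with all heads together.
per_head_topk = per_head.topk(topk, dim=0).indices              # [topk, kv_heads]

print(unified_topk.tolist())
print(per_head_topk.T.tolist())  # per-head block choices can differ across heads
```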
## Architecture

### Core Components

@@ -20,6 +129,74 @@ For sparse attention related content (block sparse attention, MInference, FlexPr

- **BlockManager** (`block_manager.py`): Paged attention with prefix caching (xxhash), default block size 4096
- **Attention** (`layers/attention.py`): FlashAttention with chunked methods for CPU offload

## PyTorch Hooks for Debugging

### Hook Positions in Qwen3

```
decoder_layer
├── input_layernorm (RMSNorm)
├── self_attn (Qwen3Attention)      ← Hook here for attention I/O after o_proj
│   ├── q_proj → q_norm → RoPE
│   ├── k_proj → k_norm → RoPE
│   ├── v_proj
│   ├── attn (Attention)            ← Hook here for Q/K/V tensors
│   │   └── FlashAttention / SDPA
│   └── o_proj
├── post_attention_layernorm (RMSNorm)
└── mlp (Qwen3MLP)
```

### Hook Types & Data Shapes

| Hook Position | Type | Captured Data |
|---------------|------|---------------|
| `self_attn` | post | `[batch, seq_len, hidden_size]` - after o_proj |
| `self_attn.attn` | pre | Q, K, V: `[seq_len, num_heads, head_dim]` - after RoPE |
| `self_attn.attn` | post | `[seq_len, num_heads, head_dim]` - before o_proj |

### Example: Capture Attention Outputs

```python
storage = {}

def make_hook(layer_id: int, storage: dict):
    def hook(module, inputs, output):
        if isinstance(output, tuple):
            attn_output = output[0]
        else:
            attn_output = output
        # nanovllm shape: [num_tokens, hidden_size] -> add batch dim
        if attn_output.dim() == 2:
            attn_output = attn_output.unsqueeze(0)
        storage[layer_id] = attn_output.detach().clone()
    return hook

# Register hooks
hooks = []
for layer_idx, layer in enumerate(model.model.layers):
    hooks.append(layer.self_attn.register_forward_hook(make_hook(layer_idx, storage)))

# Run inference...

# Cleanup
for hook in hooks:
    hook.remove()
```
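The table above also lists a *pre* hook on `self_attn.attn` for capturing Q/K/V after RoPE, while the example only shows a post hook. Below is a hedged sketch using `register_forward_pre_hook`; it assumes `attn` is called positionally as `attn(q, k, v)`, which should be verified against `nanovllm/layers/attention.py`:

```python
# Sketch only: capture Q/K/V entering self_attn.attn with a forward pre-hook.
# Assumes the module is invoked positionally as attn(q, k, v).
qkv_storage = {}

def make_pre_hook(layer_id: int, storage: dict):
    def pre_hook(module, args):
        q, k, v = args[:3]
        storage[layer_id] = tuple(t.detach().clone() for t in (q, k, v))
    return pre_hook

pre_hooks = [
    layer.self_attn.attn.register_forward_pre_hook(make_pre_hook(i, qkv_storage))
    for i, layer in enumerate(model.model.layers)
]

# Run inference...

for h in pre_hooks:
    h.remove()
```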
### Reference Implementation

Key files:
- `tests/modeling_qwen3.py`: Reference Qwen3 implementation (torch + transformers only)
- `tests/test_needle_ref.py`: Reference needle test using custom Qwen3
- `tests/test_needle.py`: Needle-in-haystack test for nanovllm

### Common Pitfalls

1. **Shape mismatch**: nanovllm uses `[num_tokens, ...]` while torch uses `[batch, seq_len, ...]`
2. **Hook position**: `self_attn` captures after o_proj, `self_attn.attn` captures before o_proj
3. **Output format**: nanovllm returns the tuple `(attn_output, None)`; handle with `output[0]`

## CPU Offload System

### Ring Buffer Design

@@ -105,7 +282,6 @@ memcpy_2d_async(gpu_buf, cpu_cache[:, block_id], dpitch, spitch, width, height,

**Files**:
- `csrc/sgdma_kernel.cu`, `csrc/sgdma.cpp`: CUDA extension
- `nanovllm/comm/sgdma.py`: Python API
- `kvcache/offload_engine.py`: Integration (4 methods updated)

(Removed in this hunk: the bullet "- `tests/test_sgdma.py`: Standalone benchmark".)

### Integration Details

@@ -210,25 +386,59 @@ def _merge_output_kernel(...):

- Total GPU time: ~1,343 ms
- **Overall speedup with Triton**: 1.67x

(Removed in this hunk: the "### Correctness Verification" subsection — **Test**: `tests/test_chunked_attention.py`; 12 test cases (6 configs × 2 dtypes); all tests PASS with max error < 0.01; float16: max_diff=0.000488, mean_diff~0.00001; bfloat16: max_diff=0.003906, mean_diff~0.0001.)

### Key Files

- `nanovllm/kvcache/chunked_attention.py`: Triton kernels + merge function

(Removed in this hunk: the bullets "- `tests/test_chunked_attention.py`: Correctness tests" and "- `tests/test_attention_offload.py`: Performance profiling".)
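For context on what the merge function in `chunked_attention.py` has to do: chunked attention computes partial outputs per KV chunk and combines them using their log-sum-exp (LSE) statistics. A minimal PyTorch sketch of that standard merge follows; it is illustrative only, and the actual Triton kernel in `nanovllm/kvcache/chunked_attention.py` may differ in layout and details:

```python
# Sketch: merging per-chunk attention outputs with log-sum-exp statistics.
import torch

def merge_attention_chunks(out1, lse1, out2, lse2):
    """Merge two partial attention results using their log-sum-exp statistics."""
    lse = torch.logaddexp(lse1, lse2)            # combined normalizer, [heads, q_len]
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)     # weight of chunk 1
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)     # weight of chunk 2
    return w1 * out1 + w2 * out2, lse

# Illustrative shapes: outputs are [num_heads, q_len, head_dim], LSE is [num_heads, q_len].
heads, q_len, dim = 8, 4, 64
out_a, out_b = torch.randn(heads, q_len, dim), torch.randn(heads, q_len, dim)
lse_a, lse_b = torch.randn(heads, q_len), torch.randn(heads, q_len)
merged, merged_lse = merge_attention_chunks(out_a, lse_a, out_b, lse_b)
```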
## Known Issues and Fixes

### Partial Last Block Bug (FIXED ✓)

**Problem**: When the prefill token count is not an exact multiple of `block_size`, decode outputs garbage.

**Root Cause**: `_chunked_decode_attention` calculated `last_block_valid_tokens` using `len(seq) - 1`, which increases during decode. But CPU blocks are fixed after prefill!

```python
# BUG: len(seq) increases each decode step
total_prefill_tokens = len(seq) - 1  # Wrong!
last_block_valid_tokens = total_prefill_tokens % block_size  # Reads garbage from CPU
```

**Fix**: Cache the original prefill length in `HybridKVCacheManager.get_prefill_len()`:

```python
# CORRECT: Use cached prefill length
total_prefill_tokens = kvcache_manager.get_prefill_len(seq)  # Fixed value
```

**Files Modified**:
- `nanovllm/kvcache/hybrid_manager.py`: Added `_prefill_len` dict and `get_prefill_len()` method
- `nanovllm/layers/attention.py`: Use `get_prefill_len()` instead of `len(seq) - 1`

### Block Size 4096 Race Condition (FIXED ✓)

**Problem**: `block_size=4096` with multiple chunks produced an `index_copy_(): index out of bounds` CUDA error during Chunk 2 processing.

**Root Cause**: Race condition between the default stream and the compute stream. In `_prepare_chunked_offload_chunk()`, the `slot_mapping` tensor was created with a `non_blocking=True` H2D transfer on the default stream. However, `store_kvcache` runs on `compute_stream`. Without synchronization, `compute_stream` could use `slot_mapping` before its transfer completed, causing corrupted indices.

**Fix** (in `attention.py`):

```python
if is_chunked_offload:
    compute_stream = context.kvcache_manager.offload_engine.compute_stream
    if k_cache.numel() and v_cache.numel():
        # CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
        compute_stream.wait_stream(torch.cuda.default_stream())
        with torch.cuda.stream(compute_stream):
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
```

**Tested block sizes**: 512, 1024, 4096, 8192 - all pass.

## Configuration

| Parameter | Default | Notes |
|-----------|---------|-------|
| `kvcache_block_size` | 1024 | Tokens per block (was 4096; 4096 now works after race condition fix) |
| `max_num_batched_tokens` | 16384 | Set = max_model_len for long context |
| `gpu_memory_utilization` | 0.9 | GPU memory fraction |
| `enable_cpu_offload` | False | Enable for long context |
DEBUG_SUMMARY.md · new file, 103 lines

# Chunked Prefill Bug Debug Summary

## Problem
`test_needle.py --enable-offload --input-len 8192` fails with garbage output.

The model generates completely wrong tokens instead of the expected "7492".

## Investigation Progress

### 1. Stream Synchronization Fix (Completed)
- Replaced the Triton `store_kvcache` kernel with pure PyTorch operations
- Moved `store_kvcache` to `compute_stream` in chunked prefill mode
- Added sync: `compute_stream.wait_event(offload_done)` after per-layer offload
- Added sync: `default_stream.wait_stream(compute_stream)` before return

### 2. KV Cache Alignment Verification (Completed)
Created alignment tests to compare K/V tensors between the torch reference and nanovllm:

**RoPE Alignment:**
- RoPE implementations match perfectly (max_diff=0.002, cosine ~1.0)
- Confirmed RoPE is NOT the cause of the bug

**K/V Cache Alignment (Chunk 0):**
- Cosine similarity: ~1.0 for all layers
- Max diff: 2-7 (grows linearly with position, characteristic of FP16 precision)
- Mean diff: < 0.001
- **Conclusion: K/V cache offload is working correctly**

### 3. Layer Output Divergence Analysis (Completed)
Created per-chunk layer output comparison:

**Chunk 0 (tokens 0-4096):**
- All layers pass with excellent cosine similarity (0.999+)
- Max diff grows in later layers but stays within an acceptable range

**Chunk 1 (tokens 4096-8192):**
- Layers 0-19: OK (cosine ~1.0)
- Layers 20-27: diverge (cosine 0.83-0.96, max_diff up to 114)
- Divergence correlates with later transformer layers

### 4. Critical Discovery: Single-Chunk Offload Also Fails
**Key finding:** Even with input_len=2048 (single chunk, no chunked attention), the model produces garbage output with CPU offload enabled.

```
# Without offload: PASSES
python tests/test_needle.py --input-len 2048
# Output: "7492" (correct)

# With offload: FAILS
python tests/test_needle.py --enable-offload --input-len 2048
# Output: "The Ble White Th G Lopsiswin..." (garbage)
```

**This proves the bug is NOT in:**
- Chunked attention logic (merge_attention_outputs)
- Multi-chunk KV loading
- Ring buffer pipeline

**The bug IS in:**
- The decode path when CPU offload is enabled
- How prefilled KV is loaded/used during decode

### 5. Decode Path Analysis (In Progress)
The decode path in CPU offload mode:
1. Prefill writes KV to GPU, offloads to CPU
2. Decode loads prefilled KV from CPU via `_decode_ring_buffer_pipeline`
3. Attend to prefilled KV + accumulated decode tokens
4. Merge results

**Observations:**
- The `prefilled_blocks` set is empty after decode (it should contain block IDs)
- The CPU cache has valid data (reasonable mean/std values)
- The decode buffer has zeros (decode tokens not being stored correctly?)

## Current Status

### Working
- Stream synchronization fixes
- K/V cache offload to CPU (verified alignment)
- RoPE implementation
- Chunked prefill attention for the first chunk

### Not Working
- Decode with CPU offload (even for single-chunk inputs)
- Multi-chunk attention (divergence in later layers for chunk 1)

## Next Steps
1. Debug why `prefilled_blocks` is empty after decode
2. Check whether the decode path correctly loads KV from CPU
3. Verify the decode buffer is being written correctly
4. Compare decode attention outputs between offload and non-offload modes

## Key Files
- `nanovllm/layers/attention.py` - Main attention implementation with chunked paths
- `nanovllm/kvcache/offload_engine.py` - CPU-GPU transfer engine
- `nanovllm/kvcache/hybrid_manager.py` - KV cache management with `prefilled_blocks`
- `nanovllm/engine/model_runner.py` - Prefill/decode orchestration

## Hypothesis
The decode path fails because:
1. `prefilled_blocks` is not being tracked correctly, causing `get_prefilled_cpu_blocks()` to return empty,
2. OR the decode attention is not correctly loading/using the prefilled KV from CPU,
3. OR there is a stream synchronization issue specific to the decode path.
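A small debug probe follows directly from the hypothesis above. It is a hedged sketch that assumes `llm.model_runner.kvcache_manager` exposes the `prefilled_blocks` set referenced in this summary (names taken from these notes, not verified against the current code):

```python
# Hedged debug sketch: inspect prefilled-block tracking after a prefill+decode run.
import os
from nanovllm import LLM, SamplingParams

path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
llm = LLM(path, enable_cpu_offload=True, max_model_len=4096)
llm.generate(["needle test prompt ..."], SamplingParams(max_tokens=4))

mgr = llm.model_runner.kvcache_manager
# Expected: a non-empty set of CPU block IDs after prefill; empty indicates the tracking bug.
print("prefilled_blocks:", getattr(mgr, "prefilled_blocks", None))
```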
bench.py · 60 lines changed

@@ -5,7 +5,7 @@ from nanovllm import LLM, SamplingParams


 def bench_decode(llm, num_seqs, input_len, output_len):
-    """Benchmark decode performance (original test)"""
+    """Benchmark decode performance"""
     seed(0)
     prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
     sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=output_len)

@@ -13,9 +13,14 @@ def bench_decode(llm, num_seqs, input_len, output_len):
     t = time.time()
     llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
     t = time.time() - t
-    total_output_tokens = num_seqs * output_len
-    throughput = total_output_tokens / t
-    print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {total_output_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
+    # Calculate metrics
+    prefill_tokens = num_seqs * input_len
+    decode_tokens = num_seqs * output_len
+    decode_throughput = decode_tokens / t
+
+    print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
+    print(f"  Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")


 def bench_prefill(llm, num_seqs, input_len):

@@ -35,32 +40,49 @@ def bench_prefill(llm, num_seqs, input_len):

 def main():
     import argparse
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(description="Benchmark nanovllm GPU performance")
     parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
-    parser.add_argument("--output-len", type=int, default=128, help="Output length in tokens")
+    parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
+    parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
+    parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
+    parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
     args = parser.parse_args()

     path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
-    # Note: Qwen3-4B-Instruct-2507 max_position_embeddings = 262144
-    max_len = 131072  # 128K tokens
-    llm = LLM(path, enforce_eager=False, max_model_len=max_len, max_num_batched_tokens=max_len)
+    max_len = args.max_len
+    print(f"\n[nanovllm GPU] max_len=120,982")
+
+    llm = LLM(
+        path,
+        enforce_eager=False,
+        max_model_len=max_len,
+        max_num_batched_tokens=max_len,
+    )

     # Warmup
-    llm.generate(["Benchmark: "], SamplingParams())
+    print("\nWarming up...")
+    llm.generate(["Benchmark warmup: "], SamplingParams(max_tokens=10))

-    # Default input lengths based on max_len
+    # Default input lengths
     prefill_input_len = args.input_len if args.input_len else max_len - 1
     decode_input_len = args.input_len if args.input_len else max_len - args.output_len

-    print("=" * 60)
-    print("Prefill Benchmark (GPU)")
-    print("=" * 60)
-    bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
-
-    # print("=" * 60)
-    # print("Decode Benchmark (GPU)")
-    # print("=" * 60)
-    # bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
+    # Determine which benchmarks to run
+    run_prefill = not args.bench_decode or args.bench_all
+    run_decode = args.bench_decode or args.bench_all
+
+    if run_prefill:
+        print("\n" + "=" * 60)
+        print("Prefill Benchmark (nanovllm GPU)")
+        print("=" * 60)
+        bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
+
+    if run_decode:
+        print("\n" + "=" * 60)
+        print("Decode Benchmark (nanovllm GPU)")
+        print("=" * 60)
+        bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)


 if __name__ == "__main__":
bench_offload.py · 143 lines changed

@@ -3,14 +3,9 @@ import time
 from random import randint, seed
 from nanovllm import LLM, SamplingParams

-# Import sparse policy classes
-from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
-from nanovllm.kvcache.sparse.hybrid import HybridPolicy
-from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy


 def bench_decode(llm, num_seqs, input_len, output_len):
-    """Benchmark decode performance (original test)"""
+    """Benchmark decode performance"""
     seed(0)
     prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
     sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=output_len)

@@ -18,9 +13,17 @@ def bench_decode(llm, num_seqs, input_len, output_len):
     t = time.time()
     llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
     t = time.time() - t
-    total_output_tokens = num_seqs * output_len
-    throughput = total_output_tokens / t
-    print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {total_output_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
+    # Calculate metrics
+    prefill_tokens = num_seqs * input_len
+    decode_tokens = num_seqs * output_len
+
+    # Approximate: assume prefill takes ~input_len/prefill_speed, rest is decode
+    # For more accurate measurement, we'd need internal timing
+    decode_throughput = decode_tokens / t  # This includes prefill time, so it's a lower bound
+
+    print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
+    print(f"  Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")


 def bench_prefill(llm, num_seqs, input_len):
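As the new comments above note, the reported decode throughput folds prefill time in and is only a lower bound. One way to get a closer decode-only number without internal timing is to time a second run that stops after the first generated token and subtract it. This is an editor's hedged sketch, not part of the diff, and it assumes a `max_tokens=1` run approximates prefill plus one decode step:

```python
# Hedged sketch: approximate decode-only throughput by subtracting a max_tokens=1 run.
import time
from nanovllm import SamplingParams

def bench_decode_only(llm, prompt_token_ids, output_len):
    # Run 1: prefill + one decode step (approximates pure prefill cost)
    t0 = time.time()
    llm.generate(prompt_token_ids,
                 SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1),
                 use_tqdm=False)
    t_prefill = time.time() - t0

    # Run 2: prefill + full decode
    t0 = time.time()
    llm.generate(prompt_token_ids,
                 SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=output_len),
                 use_tqdm=False)
    t_total = time.time() - t0

    decode_tokens = len(prompt_token_ids) * (output_len - 1)
    return decode_tokens / max(t_total - t_prefill, 1e-6)
```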
@@ -38,102 +41,70 @@ def bench_prefill(llm, num_seqs, input_len):
     print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")


-def setup_quest_policy(llm, topk_blocks=8, threshold_blocks=4):
-    """
-    Setup Quest sparse policy for decode phase.
-
-    Uses HybridPolicy: Full attention for prefill, Quest Top-K for decode.
-    """
-    import torch
-
-    kvcache_manager = llm.model_runner.kvcache_manager
-    offload_engine = kvcache_manager.offload_engine
-
-    # Get model parameters from offload engine
-    num_layers = offload_engine.num_layers
-    num_kv_heads = offload_engine.num_kv_heads
-    head_dim = offload_engine.head_dim
-    num_cpu_blocks = kvcache_manager.num_cpu_blocks
-    dtype = offload_engine.k_cache_cpu.dtype
-
-    print(f"Setting up Quest policy:")
-    print(f"  num_layers={num_layers}, num_kv_heads={num_kv_heads}, head_dim={head_dim}")
-    print(f"  num_cpu_blocks={num_cpu_blocks}, dtype={dtype}")
-    print(f"  topk_blocks={topk_blocks}, threshold_blocks={threshold_blocks}")
-
-    # Create BlockMetadataManager for storing min/max keys
-    metadata = BlockMetadataManager(
-        num_blocks=num_cpu_blocks,
-        num_layers=num_layers,
-        num_kv_heads=num_kv_heads,
-        head_dim=head_dim,
-        dtype=dtype,
-    )
-
-    # Create Quest policy for decode
-    quest_config = QuestConfig(
-        topk_blocks=topk_blocks,
-        threshold_blocks=threshold_blocks,
-    )
-    quest_policy = QuestPolicy(quest_config, metadata)
-
-    # Create Hybrid policy: Full for prefill, Quest for decode
-    hybrid_policy = HybridPolicy(
-        prefill_policy=FullAttentionPolicy(),
-        decode_policy=quest_policy,
-    )
-
-    # Set the policy
-    kvcache_manager.set_sparse_policy(hybrid_policy)
-    print(f"  Policy set: HybridPolicy(prefill=Full, decode=Quest)")
-
-    return hybrid_policy
-
-
 def main():
     import argparse
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--no-sparse", action="store_true", help="Disable sparse attention (baseline)")
-    parser.add_argument("--topk", type=int, default=8, help="Top-K blocks for Quest")
-    parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens (default: max_len - 1 for prefill, max_len - output_len for decode)")
-    parser.add_argument("--output-len", type=int, default=128, help="Output length in tokens")
+    from nanovllm.config import SparsePolicyType
+
+    parser = argparse.ArgumentParser(description="Benchmark CPU offload performance")
+    parser.add_argument("--enable-quest", action="store_true", help="Enable Quest sparse attention for decode")
+    parser.add_argument("--topk", type=int, default=16, help="Top-K blocks for Quest (default: 16)")
+    parser.add_argument("--threshold", type=int, default=4, help="Apply sparse only when blocks > threshold (default: 4)")
+    parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
+    parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
+    parser.add_argument("--num-gpu-blocks", type=int, default=6, help="Number of GPU blocks (default: 6)")
+    parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
+    parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
+    parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
     args = parser.parse_args()

     path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
-    # Note: Qwen3-4B-Instruct-2507 max_position_embeddings = 262144
-    max_len = 131072  # 128K tokens
+    max_len = args.max_len
+
+    # Setup policy configuration
+    if args.enable_quest:
+        sparse_policy = SparsePolicyType.QUEST
+        print(f"\n[Quest Sparse Attention] topk={args.topk}, threshold={args.threshold}")
+    else:
+        sparse_policy = SparsePolicyType.FULL
+        print("\n[Full Attention] baseline (no sparse)")
+
+    print(f"[Config] max_len=120,982, num_gpu_blocks={args.num_gpu_blocks}")

     llm = LLM(
         path,
         enforce_eager=False,
         max_model_len=max_len,
         max_num_batched_tokens=max_len,
         enable_cpu_offload=True,
-        num_gpu_blocks=6,  # Small GPU buffer for offload testing
+        num_gpu_blocks=args.num_gpu_blocks,
+        sparse_policy=sparse_policy,
+        sparse_topk_blocks=args.topk,
+        sparse_threshold_blocks=args.threshold,
     )

-    if not args.no_sparse:
-        # Setup Quest policy for decode (Top-K blocks, apply when > 4 blocks)
-        setup_quest_policy(llm, topk_blocks=args.topk, threshold_blocks=4)
-        print(f"\n[Quest Sparse Attention] topk={args.topk}")
-    else:
-        print("\n[Full Attention] No sparse policy (baseline)")
-
     # Warmup
-    llm.generate(["Benchmark: "], SamplingParams())
+    print("\nWarming up...")
+    llm.generate(["Benchmark warmup: "], SamplingParams(max_tokens=10))

-    # Default input lengths based on max_len
+    # Default input lengths
     prefill_input_len = args.input_len if args.input_len else max_len - 1
     decode_input_len = args.input_len if args.input_len else max_len - args.output_len

-    print("=" * 60)
-    print("Prefill Benchmark (CPU Offload)")
-    print("=" * 60)
-    bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
-
-    # print("=" * 60)
-    # print("Decode Benchmark (CPU Offload)")
-    # print("=" * 60)
-    # bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
+    # Determine which benchmarks to run
+    run_prefill = not args.bench_decode or args.bench_all
+    run_decode = args.bench_decode or args.bench_all
+
+    if run_prefill:
+        print("\n" + "=" * 60)
+        print("Prefill Benchmark (CPU Offload)")
+        print("=" * 60)
+        bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
+
+    if run_decode:
+        print("\n" + "=" * 60)
+        print("Decode Benchmark (CPU Offload)")
+        print("=" * 60)
+        bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)


 if __name__ == "__main__":
@@ -6,7 +6,7 @@ from vllm import LLM, SamplingParams


 def bench_decode(llm, num_seqs, input_len, output_len):
-    """Benchmark decode performance (original test)"""
+    """Benchmark decode performance"""
     seed(0)
     prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
     sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=output_len)

@@ -15,9 +15,14 @@ def bench_decode(llm, num_seqs, input_len, output_len):
     t = time.time()
     llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
     t = time.time() - t
-    total_output_tokens = num_seqs * output_len
-    throughput = total_output_tokens / t
-    print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {total_output_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
+    # Calculate metrics
+    prefill_tokens = num_seqs * input_len
+    decode_tokens = num_seqs * output_len
+    decode_throughput = decode_tokens / t
+
+    print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
+    print(f"  Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")


 def bench_prefill(llm, num_seqs, input_len):

@@ -38,32 +43,50 @@ def bench_prefill(llm, num_seqs, input_len):

 def main():
     import argparse
-    parser = argparse.ArgumentParser()
+    parser = argparse.ArgumentParser(description="Benchmark vLLM performance (for comparison)")
     parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
-    parser.add_argument("--output-len", type=int, default=128, help="Output length in tokens")
+    parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
+    parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
+    parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
+    parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
     args = parser.parse_args()

     path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
-    # Note: Qwen3-4B-Instruct-2507 max_position_embeddings = 262144
-    max_len = 131072  # 128K tokens
-    llm = LLM(path, enforce_eager=False, max_model_len=max_len, max_num_seqs=128, gpu_memory_utilization=0.9)
+    max_len = args.max_len
+    print(f"\n[vLLM] max_len=120,982")
+
+    llm = LLM(
+        path,
+        enforce_eager=False,
+        max_model_len=max_len,
+        max_num_seqs=128,
+        gpu_memory_utilization=0.9,
+    )

     # Warmup
-    llm.generate([dict(prompt_token_ids=[0])], SamplingParams())
+    print("\nWarming up...")
+    llm.generate([dict(prompt_token_ids=[0, 1, 2])], SamplingParams(max_tokens=10))

-    # Default input lengths based on max_len
+    # Default input lengths
     prefill_input_len = args.input_len if args.input_len else max_len - 1
     decode_input_len = args.input_len if args.input_len else max_len - args.output_len

-    print("=" * 60)
-    print("Prefill Benchmark (vLLM)")
-    print("=" * 60)
-    bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
-
-    # print("=" * 60)
-    # print("Decode Benchmark (vLLM)")
-    # print("=" * 60)
-    # bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
+    # Determine which benchmarks to run
+    run_prefill = not args.bench_decode or args.bench_all
+    run_decode = args.bench_decode or args.bench_all
+
+    if run_prefill:
+        print("\n" + "=" * 60)
+        print("Prefill Benchmark (vLLM)")
+        print("=" * 60)
+        bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
+
+    if run_decode:
+        print("\n" + "=" * 60)
+        print("Decode Benchmark (vLLM)")
+        print("=" * 60)
+        bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)


 if __name__ == "__main__":
@@ -1,9 +1,16 @@
 import os
 from dataclasses import dataclass
+from enum import Enum, auto
 from transformers import AutoConfig
 import torch


+class SparsePolicyType(Enum):
+    """Sparse attention policy types."""
+    FULL = auto()   # No sparse attention (load all blocks)
+    QUEST = auto()  # Query-aware Top-K block selection (decode only)
+
+
 @dataclass
 class Config:
     model: str

@@ -15,7 +22,7 @@ class Config:
     enforce_eager: bool = False
     hf_config: AutoConfig | None = None
     eos: int = -1
-    kvcache_block_size: int = 4096
+    kvcache_block_size: int = 1024
     num_kvcache_blocks: int = -1
     dtype: str | None = None  # "float16", "bfloat16", or None (use model default)

@@ -30,9 +37,9 @@ class Config:
     num_cpu_kvcache_blocks: int = -1

     # Sparse attention configuration
-    sparse_policy: str | None = None  # "vertical_slash", "quest", "streaming_llm", or None
-    sparse_num_sink_blocks: int = 1  # Number of sink blocks for sparse patterns
-    sparse_local_window_blocks: int = 2  # Local window size for VerticalSlash
+    # Quest: decode-only sparse attention with Top-K block selection
+    # FULL: no sparse attention (load all blocks)
+    sparse_policy: SparsePolicyType = SparsePolicyType.FULL
     sparse_topk_blocks: int = 8  # Top-K blocks for Quest
     sparse_threshold_blocks: int = 4  # Apply sparse only when blocks > threshold
nanovllm/debug/__init__.py · new file, 49 lines

"""
Breakpoint debugging tools for aligning nanovllm with reference implementations.

This module provides a generator-based breakpoint aligner that enables step-by-step
comparison between nanovllm and torch reference model outputs.

Example:
    >>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
    >>> from tests.modeling_qwen3 import Qwen3ForCausalLM
    >>>
    >>> # Load models
    >>> torch_model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch.float16)
    >>> nanovllm_model = ...  # Your nanovllm model
    >>>
    >>> # Create adapters
    >>> ref = TorchSteppable(torch_model)
    >>> test = NanovllmSteppable(nanovllm_model)
    >>>
    >>> # Run alignment
    >>> aligner = BreakpointAligner(ref, test)
    >>> result = aligner.align(input_ids)
    >>> print(result)
"""

from .breakpoints import BreakpointType, Breakpoint
from .comparator import TensorComparator, ComparisonResult
from .aligner import BreakpointAligner, AlignmentResult
from .adapters import SteppableModel, TorchSteppable, NanovllmSteppable
from .utils import setup_prefill_context, setup_decode_context, cleanup_context

__all__ = [
    # Core classes
    "BreakpointAligner",
    "AlignmentResult",
    # Breakpoints
    "BreakpointType",
    "Breakpoint",
    # Comparator
    "TensorComparator",
    "ComparisonResult",
    # Adapters
    "SteppableModel",
    "TorchSteppable",
    "NanovllmSteppable",
    # Utils
    "setup_prefill_context",
    "setup_decode_context",
    "cleanup_context",
]
nanovllm/debug/adapters/__init__.py · new file, 11 lines

"""Model adapters for breakpoint alignment."""

from .base import SteppableModel
from .torch_adapter import TorchSteppable
from .nanovllm_adapter import NanovllmSteppable

__all__ = [
    "SteppableModel",
    "TorchSteppable",
    "NanovllmSteppable",
]
nanovllm/debug/adapters/base.py · new file, 59 lines

"""Base class for steppable model adapters."""

from abc import ABC, abstractmethod
from typing import Generator, Set, Optional
import torch

from ..breakpoints import Breakpoint, BreakpointType


class SteppableModel(ABC):
    """
    Abstract base class for models that can yield at breakpoints.

    Subclasses implement the step() method as a generator that yields
    Breakpoint objects at each enabled breakpoint during forward pass.
    """

    def __init__(self, enabled_breakpoints: Optional[Set[BreakpointType]] = None):
        """
        Args:
            enabled_breakpoints: Set of breakpoint types to yield at.
                If None, yields at all breakpoints.
        """
        self.enabled_breakpoints = enabled_breakpoints

    def is_enabled(self, bp_type: BreakpointType) -> bool:
        """Check if a breakpoint type is enabled."""
        if self.enabled_breakpoints is None:
            return True
        return bp_type in self.enabled_breakpoints

    @abstractmethod
    def step(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> Generator[Breakpoint, None, torch.Tensor]:
        """
        Generator that yields Breakpoint objects at enabled breakpoints.

        Args:
            input_ids: Input token IDs
            positions: Position IDs (optional, auto-generated if None)
            is_prefill: True for prefill phase, False for decode

        Yields:
            Breakpoint objects at each enabled checkpoint

        Returns:
            Final output tensor (logits)
        """
        pass

    @property
    @abstractmethod
    def num_layers(self) -> int:
        """Return the number of decoder layers."""
        pass
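Since `step()` is a generator that both yields breakpoints and *returns* the final logits (per its `Generator[Breakpoint, None, torch.Tensor]` annotation), a caller has to read the return value from `StopIteration`. A minimal driving-loop sketch, independent of `BreakpointAligner` (whose internals are not shown here):

```python
# Sketch: driving a SteppableModel.step() generator and capturing its return value.
# The logits come back via StopIteration.value, per Python's generator-return semantics.
def drive(steppable, input_ids, is_prefill=True):
    gen = steppable.step(input_ids, is_prefill=is_prefill)
    breakpoints = []
    while True:
        try:
            bp = next(gen)           # a Breakpoint yielded at an enabled checkpoint
            breakpoints.append(bp)
        except StopIteration as stop:
            logits = stop.value      # the generator's return value (final logits)
            return breakpoints, logits
```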
nanovllm/debug/adapters/nanovllm_adapter.py · new file, 235 lines

"""Nanovllm model adapter for breakpoint alignment."""

from typing import Generator, Set, Optional, Dict, Any, List
import torch

from nanovllm.utils.context import set_context, reset_context
from ..breakpoints import Breakpoint, BreakpointType
from .base import SteppableModel


class NanovllmSteppable(SteppableModel):
    """
    Steppable adapter for nanovllm Qwen3 implementation.

    Uses PyTorch hooks to capture intermediate values during forward pass,
    then yields them as breakpoints after execution completes.

    Key challenges handled:
    1. Shape difference: nanovllm uses [num_tokens, hidden] vs [batch, seq, hidden]
    2. Context-based attention: must call set_context() before forward
    3. Fused operations: decoder layer returns (hidden_states, residual) tuple
    """

    def __init__(
        self,
        model,
        enabled_breakpoints: Optional[Set[BreakpointType]] = None,
    ):
        """
        Args:
            model: Qwen3ForCausalLM from nanovllm
            enabled_breakpoints: Set of breakpoint types to yield at
        """
        super().__init__(enabled_breakpoints)
        self.model = model
        self.model.eval()
        self._hooks: List[Any] = []
        self._captured: Dict[str, torch.Tensor] = {}

    @property
    def num_layers(self) -> int:
        return len(self.model.model.layers)

    def _register_hooks(self):
        """Register forward hooks on all relevant modules."""
        self._hooks = []
        self._captured = {}

        # Hook for embedding output
        def embed_hook(module, input, output):
            self._captured["embed"] = output.detach().clone()

        self._hooks.append(
            self.model.model.embed_tokens.register_forward_hook(embed_hook)
        )

        # Hooks for each decoder layer
        for layer_idx in range(self.num_layers):
            layer = self.model.model.layers[layer_idx]

            def make_layer_hook(idx):
                def hook(module, input, output):
                    # Decoder layer returns (hidden_states, residual)
                    # hidden_states is MLP output, residual is accumulated residual
                    # To match torch reference, we need hidden_states + residual
                    if isinstance(output, tuple) and len(output) >= 2:
                        hidden_states, residual = output[0], output[1]
                        full_output = hidden_states + residual
                    else:
                        full_output = output
                    self._captured[f"layer_{idx}"] = full_output.detach().clone()
                return hook

            self._hooks.append(
                layer.register_forward_hook(make_layer_hook(layer_idx))
            )

        # Hook for final norm
        def final_norm_hook(module, input, output):
            # Final norm returns (hidden_states, _) for fused add
            hidden_states = output[0] if isinstance(output, tuple) else output
            self._captured["final_norm"] = hidden_states.detach().clone()

        self._hooks.append(
            self.model.model.norm.register_forward_hook(final_norm_hook)
        )

        # Hook for lm_head
        def lm_head_hook(module, input, output):
            self._captured["lm_head"] = output.detach().clone()

        self._hooks.append(
            self.model.lm_head.register_forward_hook(lm_head_hook)
        )

    def _remove_hooks(self):
        """Remove all registered hooks."""
        for hook in self._hooks:
            hook.remove()
        self._hooks = []

    def _setup_context(self, seq_len: int, device: torch.device, is_prefill: bool):
        """
        Set up nanovllm context for forward pass.

        For alignment testing, we use simple context without real KV cache.
        """
        if is_prefill:
            # Prefill: process all tokens at once
            cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
            # Use -1 for slot_mapping to skip KV cache writes
            slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)

            set_context(
                is_prefill=True,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=seq_len,
                max_seqlen_k=seq_len,
                slot_mapping=slot_mapping,
                is_chunked_prefill=False,
            )
        else:
            # Decode: single token generation
            # For decode, we need context_lens and block_tables
            # For alignment testing without real KV cache, we use minimal setup
            context_lens = torch.tensor([seq_len - 1], dtype=torch.int32, device=device)
            # Single token slot
            slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
            # Empty block tables (no KV cache)
            block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)

            set_context(
                is_prefill=False,
                slot_mapping=slot_mapping,
                context_lens=context_lens,
                block_tables=block_tables,
            )

    def _normalize_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Normalize nanovllm tensor shape to [batch, seq_len, ...].

        nanovllm uses [num_tokens, ...] format without batch dimension.
        We add batch dimension for comparison with torch model.
        """
        if tensor.dim() == 2:  # [num_tokens, hidden_size]
            return tensor.unsqueeze(0)
        return tensor

    def step(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> Generator[Breakpoint, None, torch.Tensor]:
        """
        Execute nanovllm forward pass with hooks to capture breakpoints.

        Unlike the torch adapter which manually steps through each component,
        we run the full forward pass and collect captured values afterward.
        """
        # Ensure 1D for nanovllm (it expects [num_tokens])
        if input_ids.dim() == 2:
            input_ids = input_ids.squeeze(0)

        seq_len = input_ids.numel()
        device = input_ids.device

        # Generate position IDs if not provided
        if positions is None:
            positions = torch.arange(seq_len, device=device)
        elif positions.dim() == 2:
            positions = positions.squeeze(0)

        # Register hooks
        self._register_hooks()

        try:
            # Setup context for attention
            self._setup_context(seq_len, device, is_prefill)

            # Run forward pass (hooks capture everything)
            with torch.no_grad():
                hidden_states = self.model(input_ids, positions)
                logits = self.model.compute_logits(hidden_states)

            reset_context()

            # Yield breakpoints in order from captured data

            # EMBEDDING
            if self.is_enabled(BreakpointType.EMBEDDING) and "embed" in self._captured:
                yield Breakpoint(
                    bp_type=BreakpointType.EMBEDDING,
                    layer_idx=None,
                    tensor=self._normalize_tensor(self._captured["embed"]),
                    name="Embedding",
                )

            # LAYER_OUTPUT for each layer
            if self.is_enabled(BreakpointType.LAYER_OUTPUT):
                for layer_idx in range(self.num_layers):
                    key = f"layer_{layer_idx}"
                    if key in self._captured:
                        yield Breakpoint(
                            bp_type=BreakpointType.LAYER_OUTPUT,
                            layer_idx=layer_idx,
                            tensor=self._normalize_tensor(self._captured[key]),
                            name=f"Layer {layer_idx}",
                        )

            # FINAL_NORM
            if self.is_enabled(BreakpointType.FINAL_NORM) and "final_norm" in self._captured:
                yield Breakpoint(
                    bp_type=BreakpointType.FINAL_NORM,
                    layer_idx=None,
                    tensor=self._normalize_tensor(self._captured["final_norm"]),
                    name="Final Norm",
                )

            # LM_HEAD
            if self.is_enabled(BreakpointType.LM_HEAD) and "lm_head" in self._captured:
                yield Breakpoint(
                    bp_type=BreakpointType.LM_HEAD,
                    layer_idx=None,
                    tensor=self._normalize_tensor(self._captured["lm_head"]),
                    name="LM Head",
                )

            return logits

        finally:
            self._remove_hooks()
            self._captured = {}
nanovllm/debug/adapters/torch_adapter.py (new file, 119 lines)
@@ -0,0 +1,119 @@
"""Torch reference model adapter for breakpoint alignment."""

from typing import Generator, Set, Optional
import torch

from ..breakpoints import Breakpoint, BreakpointType
from .base import SteppableModel


class TorchSteppable(SteppableModel):
    """
    Steppable adapter for the torch reference Qwen3 implementation.

    Wraps tests/modeling_qwen3.py Qwen3ForCausalLM model.
    """

    def __init__(
        self,
        model,
        enabled_breakpoints: Optional[Set[BreakpointType]] = None,
    ):
        """
        Args:
            model: Qwen3ForCausalLM from tests/modeling_qwen3.py
            enabled_breakpoints: Set of breakpoint types to yield at
        """
        super().__init__(enabled_breakpoints)
        self.model = model
        self.model.eval()

    @property
    def num_layers(self) -> int:
        return len(self.model.model.layers)

    def step(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> Generator[Breakpoint, None, torch.Tensor]:
        """
        Generator that manually steps through the torch model.

        The torch model uses [batch, seq_len, hidden_size] shapes.
        """
        # Ensure batch dimension
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Generate position IDs if not provided
        if positions is None:
            positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        elif positions.dim() == 1:
            positions = positions.unsqueeze(0)

        with torch.no_grad():
            # === EMBEDDING ===
            hidden_states = self.model.model.embed_tokens(input_ids)

            if self.is_enabled(BreakpointType.EMBEDDING):
                yield Breakpoint(
                    bp_type=BreakpointType.EMBEDDING,
                    layer_idx=None,
                    tensor=hidden_states.detach().clone(),
                    name="Embedding",
                )

            # Create causal attention mask
            causal_mask = torch.triu(
                torch.full((seq_len, seq_len), float("-inf"), device=device),
                diagonal=1,
            )
            attention_mask = causal_mask.unsqueeze(0).unsqueeze(0)

            # === DECODER LAYERS ===
            for layer_idx, layer in enumerate(self.model.model.layers):
                hidden_states, _, _ = layer(
                    hidden_states=hidden_states,
                    position_ids=positions,
                    attention_mask=attention_mask,
                    past_key_value=None,
                    use_cache=False,
                    output_qkv=False,
                )

                if self.is_enabled(BreakpointType.LAYER_OUTPUT):
                    yield Breakpoint(
                        bp_type=BreakpointType.LAYER_OUTPUT,
                        layer_idx=layer_idx,
                        tensor=hidden_states.detach().clone(),
                        name=f"Layer {layer_idx}",
                    )

            # === FINAL NORM ===
            hidden_states = self.model.model.norm(hidden_states)

            if self.is_enabled(BreakpointType.FINAL_NORM):
                yield Breakpoint(
                    bp_type=BreakpointType.FINAL_NORM,
                    layer_idx=None,
                    tensor=hidden_states.detach().clone(),
                    name="Final Norm",
                )

            # === LM HEAD ===
            logits = self.model.lm_head(hidden_states)

            if self.is_enabled(BreakpointType.LM_HEAD):
                yield Breakpoint(
                    bp_type=BreakpointType.LM_HEAD,
                    layer_idx=None,
                    tensor=logits.detach().clone(),
                    name="LM Head",
                )

            return logits
nanovllm/debug/aligner.py (new file, 211 lines)
@@ -0,0 +1,211 @@
"""Breakpoint aligner for comparing model outputs."""

from dataclasses import dataclass, field
from typing import List, Tuple, Optional
import torch

from .breakpoints import Breakpoint
from .comparator import TensorComparator, ComparisonResult
from .adapters.base import SteppableModel


@dataclass
class AlignmentResult:
    """Result of an alignment test."""
    passed: bool
    all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = field(default_factory=list)
    failed_at: Optional[Breakpoint] = None
    message: str = ""

    def __repr__(self) -> str:
        passed_count = sum(1 for _, _, c in self.all_comparisons if c.passed)
        total = len(self.all_comparisons)
        status = "PASSED" if self.passed else "FAILED"
        return f"AlignmentResult({status}, {passed_count}/{total} breakpoints passed)"


class BreakpointAligner:
    """
    Orchestrates alternating execution of reference and test models,
    comparing outputs at each breakpoint.

    Example:
        >>> from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
        >>> ref = TorchSteppable(torch_model)
        >>> test = NanovllmSteppable(nanovllm_model)
        >>> aligner = BreakpointAligner(ref, test)
        >>> result = aligner.align(input_ids)
        >>> print(result)
    """

    def __init__(
        self,
        ref_model: SteppableModel,
        test_model: SteppableModel,
        comparator: Optional[TensorComparator] = None,
        stop_on_error: bool = True,
        verbose: bool = True,
    ):
        """
        Args:
            ref_model: Reference (torch) steppable model
            test_model: Test (nanovllm) steppable model
            comparator: Tensor comparator instance (uses default if None)
            stop_on_error: If True, stop at first mismatch
            verbose: If True, print comparison results
        """
        self.ref_model = ref_model
        self.test_model = test_model
        self.comparator = comparator or TensorComparator()
        self.stop_on_error = stop_on_error
        self.verbose = verbose

    def align(
        self,
        input_ids: torch.Tensor,
        positions: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
    ) -> AlignmentResult:
        """
        Run both models with same input, comparing at each breakpoint.

        Args:
            input_ids: Input token IDs
            positions: Position IDs (optional)
            is_prefill: True for prefill phase, False for decode

        Returns:
            AlignmentResult with pass/fail status and details
        """
        all_comparisons: List[Tuple[Breakpoint, Breakpoint, ComparisonResult]] = []

        if self.verbose:
            phase = "prefill" if is_prefill else "decode"
            print(f"\n{'='*60}")
            print(f"Alignment Test ({phase})")
            print(f"{'='*60}")

        # Start both generators
        ref_gen = self.ref_model.step(input_ids, positions, is_prefill)
        test_gen = self.test_model.step(input_ids, positions, is_prefill)

        try:
            while True:
                # Get next breakpoint from reference
                try:
                    ref_bp = next(ref_gen)
                except StopIteration:
                    break

                # Get corresponding breakpoint from test
                try:
                    test_bp = next(test_gen)
                except StopIteration:
                    if self.verbose:
                        print(f"Test model ended early at {ref_bp.name}")
                    return AlignmentResult(
                        passed=False,
                        all_comparisons=all_comparisons,
                        failed_at=ref_bp,
                        message=f"Test model ended early at {ref_bp.name}",
                    )

                # Verify breakpoints match
                if ref_bp.bp_type != test_bp.bp_type:
                    msg = f"Breakpoint type mismatch: {ref_bp.bp_type} vs {test_bp.bp_type}"
                    if self.verbose:
                        print(msg)
                    return AlignmentResult(
                        passed=False,
                        all_comparisons=all_comparisons,
                        failed_at=ref_bp,
                        message=msg,
                    )

                if ref_bp.layer_idx != test_bp.layer_idx:
                    msg = f"Layer index mismatch: {ref_bp.layer_idx} vs {test_bp.layer_idx}"
                    if self.verbose:
                        print(msg)
                    return AlignmentResult(
                        passed=False,
                        all_comparisons=all_comparisons,
                        failed_at=ref_bp,
                        message=msg,
                    )

                # Normalize shapes for comparison
                ref_t = ref_bp.normalize_shape()
                test_t = test_bp.normalize_shape()

                # Handle shape mismatches
                if ref_t.shape != test_t.shape:
                    if self.verbose:
                        print(f"[{ref_bp.name}] Shape mismatch: ref={ref_t.shape} vs test={test_t.shape}")

                    # Try to reshape if element count matches
                    if ref_t.numel() == test_t.numel():
                        test_t = test_t.view(ref_t.shape)
                    else:
                        msg = f"Shape mismatch at {ref_bp.name}: {ref_t.shape} vs {test_t.shape}"
                        return AlignmentResult(
                            passed=False,
                            all_comparisons=all_comparisons,
                            failed_at=ref_bp,
                            message=msg,
                        )

                # Compare tensors
                result = self.comparator.compare(ref_t, test_t, ref_bp.name)
                all_comparisons.append((ref_bp, test_bp, result))

                if self.verbose:
                    status = "\u2713" if result.passed else "\u2717"
                    print(f"{status} [{ref_bp.name}] cos={result.cosine_similarity:.6f}, max_diff={result.max_abs_diff:.2e}")

                if not result.passed and self.stop_on_error:
                    if self.verbose:
                        print(f"\nStopped at {ref_bp.name} (stop_on_error=True)")
                        print(result.message)
                    return AlignmentResult(
                        passed=False,
                        all_comparisons=all_comparisons,
                        failed_at=ref_bp,
                        message=f"Alignment failed at {ref_bp.name}",
                    )

            # Check for extra test breakpoints
            try:
                extra_bp = next(test_gen)
                msg = f"Test model has extra breakpoints starting at {extra_bp.name}"
                if self.verbose:
                    print(msg)
                return AlignmentResult(
                    passed=False,
                    all_comparisons=all_comparisons,
                    message=msg,
                )
            except StopIteration:
                pass

        except Exception as e:
            msg = f"Exception during alignment: {str(e)}"
            if self.verbose:
                print(msg)
            raise

        # Summary
        all_passed = all(comp[2].passed for comp in all_comparisons)
        passed_count = sum(1 for _, _, c in all_comparisons if c.passed)
        total = len(all_comparisons)

        if self.verbose:
            print(f"{'='*60}")
            status = "PASSED" if all_passed else "FAILED"
            print(f"Result: {status} ({passed_count}/{total} breakpoints)")
            print(f"{'='*60}\n")

        return AlignmentResult(
            passed=all_passed,
            all_comparisons=all_comparisons,
            message="All breakpoints aligned" if all_passed else "Some breakpoints failed",
        )
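A hedged sketch building on the docstring example above: loosening the comparator thresholds and disabling `stop_on_error` so every breakpoint gets compared in one pass. The model objects and `input_ids` are assumed to have been constructed elsewhere:

```python
# Sketch only: torch_model, nanovllm_model and input_ids are assumed.
from nanovllm.debug import BreakpointAligner, TorchSteppable, NanovllmSteppable
from nanovllm.debug.comparator import TensorComparator

aligner = BreakpointAligner(
    ref_model=TorchSteppable(torch_model),
    test_model=NanovllmSteppable(nanovllm_model),
    comparator=TensorComparator(cosine_threshold=0.99, max_diff_threshold=0.5),
    stop_on_error=False,   # keep going so every breakpoint is compared
    verbose=True,
)
result = aligner.align(input_ids)
for ref_bp, test_bp, cmp in result.all_comparisons:
    if not cmp.passed:
        print(ref_bp.name, cmp)   # ComparisonResult.__repr__ shows cos and max_diff
```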
nanovllm/debug/breakpoints.py (new file, 39 lines)
@@ -0,0 +1,39 @@
"""Breakpoint types and data structures for alignment debugging."""

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional
import torch


class BreakpointType(Enum):
    """Types of breakpoints in the model forward pass."""
    EMBEDDING = auto()     # After embed_tokens
    LAYER_OUTPUT = auto()  # After each decoder layer
    FINAL_NORM = auto()    # After final RMSNorm
    LM_HEAD = auto()       # After lm_head (logits)


@dataclass
class Breakpoint:
    """A captured breakpoint with tensor data."""
    bp_type: BreakpointType
    layer_idx: Optional[int]  # None for EMBEDDING, FINAL_NORM, LM_HEAD
    tensor: torch.Tensor
    name: str

    def normalize_shape(self) -> torch.Tensor:
        """
        Normalize tensor shape for comparison.

        nanovllm uses [num_tokens, hidden_size] while torch uses
        [batch, seq_len, hidden_size]. This adds a batch dimension
        to 2D tensors for comparison.
        """
        if self.tensor.dim() == 2:
            return self.tensor.unsqueeze(0)
        return self.tensor

    def __repr__(self) -> str:
        shape_str = "x".join(str(d) for d in self.tensor.shape)
        return f"Breakpoint({self.name}, shape={shape_str}, dtype={self.tensor.dtype})"
nanovllm/debug/comparator.py (new file, 94 lines)
@@ -0,0 +1,94 @@
"""Tensor comparison utilities for alignment debugging."""

from dataclasses import dataclass
import torch
import torch.nn.functional as F


@dataclass
class ComparisonResult:
    """Result of comparing two tensors."""
    passed: bool
    cosine_similarity: float
    max_abs_diff: float
    mean_abs_diff: float
    message: str

    def __repr__(self) -> str:
        status = "\u2713" if self.passed else "\u2717"
        return f"{status} cos={self.cosine_similarity:.6f}, max_diff={self.max_abs_diff:.2e}"


class TensorComparator:
    """Compares tensors using cosine similarity and absolute differences."""

    def __init__(
        self,
        cosine_threshold: float = 0.999,
        max_diff_threshold: float = 0.1,
        mean_diff_threshold: float = 0.01,
    ):
        """
        Args:
            cosine_threshold: Minimum cosine similarity to pass (0-1)
            max_diff_threshold: Maximum allowed absolute difference
            mean_diff_threshold: Maximum allowed mean absolute difference
        """
        self.cosine_threshold = cosine_threshold
        self.max_diff_threshold = max_diff_threshold
        self.mean_diff_threshold = mean_diff_threshold

    def compare(
        self,
        ref: torch.Tensor,
        test: torch.Tensor,
        name: str = "",
    ) -> ComparisonResult:
        """
        Compare two tensors and return detailed result.

        Args:
            ref: Reference tensor
            test: Test tensor
            name: Name for the comparison (used in message)

        Returns:
            ComparisonResult with pass/fail status and metrics
        """
        # Convert to float32 for comparison
        ref_f = ref.float().flatten()
        test_f = test.float().flatten()

        # Cosine similarity
        cos_sim = F.cosine_similarity(
            ref_f.unsqueeze(0),
            test_f.unsqueeze(0)
        ).item()

        # Absolute differences
        diff = (ref.float() - test.float()).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()

        # Check thresholds
        passed = (
            cos_sim >= self.cosine_threshold and
            max_diff <= self.max_diff_threshold and
            mean_diff <= self.mean_diff_threshold
        )

        status = "PASS" if passed else "FAIL"
        message = (
            f"[{name}] {status}\n"
            f"  Cosine Similarity: {cos_sim:.6f} (threshold: {self.cosine_threshold})\n"
            f"  Max Abs Diff: {max_diff:.6f} (threshold: {self.max_diff_threshold})\n"
            f"  Mean Abs Diff: {mean_diff:.6f} (threshold: {self.mean_diff_threshold})"
        )

        return ComparisonResult(
            passed=passed,
            cosine_similarity=cos_sim,
            max_abs_diff=max_diff,
            mean_abs_diff=mean_diff,
            message=message,
        )
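For quick ad-hoc checks outside the aligner, the comparator can also be used directly. A small sketch with synthetic tensors standing in for real activations:

```python
# Sketch only: synthetic tensors stand in for real activations.
import torch
from nanovllm.debug.comparator import TensorComparator

ref = torch.randn(1, 8, 1024)
test = ref + 1e-3 * torch.randn_like(ref)   # small perturbation

cmp = TensorComparator(cosine_threshold=0.999, max_diff_threshold=0.1)
result = cmp.compare(ref, test, name="Layer 0")
print(result)          # one-line summary, e.g. "✓ cos=0.999999, max_diff=5.1e-03"
print(result.message)  # full PASS/FAIL breakdown against each threshold
```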
nanovllm/debug/utils.py (new file, 51 lines)
@@ -0,0 +1,51 @@
"""Utility functions for breakpoint alignment debugging."""

import torch
from nanovllm.utils.context import set_context, reset_context


def setup_prefill_context(seq_len: int, device: torch.device):
    """
    Set up nanovllm context for prefill alignment testing.

    Args:
        seq_len: Sequence length
        device: Target device
    """
    cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
    slot_mapping = torch.full((seq_len,), -1, dtype=torch.int32, device=device)

    set_context(
        is_prefill=True,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        slot_mapping=slot_mapping,
        is_chunked_prefill=False,
    )


def setup_decode_context(context_len: int, device: torch.device):
    """
    Set up nanovllm context for decode alignment testing.

    Args:
        context_len: Context length (number of previous tokens)
        device: Target device
    """
    context_lens = torch.tensor([context_len], dtype=torch.int32, device=device)
    slot_mapping = torch.tensor([-1], dtype=torch.int32, device=device)
    block_tables = torch.zeros((1, 1), dtype=torch.int32, device=device)

    set_context(
        is_prefill=False,
        slot_mapping=slot_mapping,
        context_lens=context_lens,
        block_tables=block_tables,
    )


def cleanup_context():
    """Reset nanovllm context after alignment testing."""
    reset_context()
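These helpers exist because nanovllm's attention path reads its metadata from a global context rather than from function arguments, so a standalone forward pass outside the engine must set that context first. A hedged sketch of the intended call pattern; the model object and token IDs are assumed:

```python
# Sketch only: `model` is a loaded nanovllm Qwen3ForCausalLM and `input_ids`
# a 1-D [num_tokens] tensor on the same CUDA device.
import torch
from nanovllm.debug.utils import setup_prefill_context, cleanup_context

def debug_prefill(model, input_ids: torch.Tensor) -> torch.Tensor:
    setup_prefill_context(seq_len=input_ids.numel(), device=input_ids.device)
    try:
        positions = torch.arange(input_ids.numel(), device=input_ids.device)
        with torch.no_grad():
            hidden = model(input_ids, positions)
            logits = model.compute_logits(hidden)
    finally:
        cleanup_context()   # always reset the global context
    return logits
```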
@@ -62,6 +62,8 @@ class LLMEngine:
         token_ids = self.model_runner.call("run", seqs, is_prefill)
         self.scheduler.postprocess(seqs, token_ids)
         outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
+
+        #> Calculate number of tokens processed
         num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
         return outputs, num_tokens

@@ -35,7 +35,10 @@ class ModelRunner:
         self.model = Qwen3ForCausalLM(hf_config)
         load_model(self.model, config.model)
         self.sampler = GreedySampler()
+
+        #> Disable warmup for debugging
         self.warmup_model()
+
         self.allocate_kv_cache()
         if not self.enforce_eager:
             self.capture_cudagraph()
@@ -59,7 +62,7 @@ class ModelRunner:
         self.shm.unlink()
         if not self.enforce_eager:
             del self.graphs, self.graph_pool
-        torch.cuda.synchronize()
+        # torch.cuda.synchronize()
         dist.destroy_process_group()

     def loop(self):
@@ -153,6 +156,22 @@ class ModelRunner:
             dtype=hf_config.torch_dtype,
         )
+
+        # Initialize sparse policy if manager has one (CPU offload mode)
+        if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
+            self.kvcache_manager.sparse_policy.initialize(
+                num_layers=hf_config.num_hidden_layers,
+                num_kv_heads=num_kv_heads,
+                head_dim=head_dim,
+                num_cpu_blocks=config.num_cpu_kvcache_blocks,
+                dtype=hf_config.torch_dtype,
+                device=torch.device("cuda"),
+            )
+
+            logger.info(
+                f"Sparse policy initialized: {config.sparse_policy.name} "
+                f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
+            )
+
         # Log KV cache allocation info with detailed per-token breakdown
         gpu_memory_mb = config.num_gpu_kvcache_blocks * block_bytes / (1024 ** 2)
         cpu_memory_mb = config.num_cpu_kvcache_blocks * block_bytes / (1024 ** 2)
@@ -194,7 +213,7 @@ class ModelRunner:
             f"block_size={self.block_size}"
         )

-        # Bind layer caches to attention modules and set layer_id
+        #> Bind layer caches to attention modules and set layer_id
         layer_id = 0
         for module in self.model.modules():
             if hasattr(module, "k_cache") and hasattr(module, "v_cache"):
@@ -379,7 +398,7 @@ class ModelRunner:
         return self.model.compute_logits(graph_vars["outputs"][:bs])

     def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]:
-        # Check if Chunked Offload mode should be used (all blocks on CPU)
+        #> Check if Chunked Offload mode should be used (all blocks on CPU)
         if hasattr(self, 'kvcache_manager') and hasattr(self.kvcache_manager, 'get_all_cpu_blocks'):
             use_chunked_offload = self._should_use_chunked_offload(seqs, is_prefill)
             if use_chunked_offload:
@@ -388,6 +407,7 @@ class ModelRunner:
             else:
                 return self.run_chunked_offload_decode(seqs)

+        #> Following Code will not use Chunked Offload mode
         input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs)
         temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
         logits = self.run_model(input_ids, positions, is_prefill)
@@ -435,8 +455,6 @@ class ModelRunner:
         3. After each chunk, offload from ring buffer slot to CPU
         4. All N-1 other slots are used to load previous chunks for attention
         """
-        import sys
-
         assert len(seqs) == 1, "Ring buffer prefill only supports single sequence"
         seq = seqs[0]

@@ -446,10 +464,9 @@ class ModelRunner:

         total_tokens = len(seq)
         num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk
-        print(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
-              f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
-              f"total_chunks={num_chunks}",
-              file=sys.stderr)
+        logger.debug(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
+                     f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
+                     f"total_chunks={num_chunks}")

         chunk_idx = 0
         logits = None
@@ -468,9 +485,8 @@ class ModelRunner:
             # CPU block index for this chunk
             block_idx = chunk_idx

-            print(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
-                  f"write_slot={write_slot}",
-                  file=sys.stderr)
+            logger.debug(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
+                         f"write_slot={write_slot}")

             # Prepare inputs
             input_ids, positions = self._prepare_chunked_offload_chunk(
@@ -480,7 +496,7 @@ class ModelRunner:
             if input_ids.numel() == 0:
                 break

-            # Run model forward
+            #> Run model forward
             logits = self.run_model(input_ids, positions, is_prefill=True)
             reset_context()

@@ -489,27 +505,17 @@ class ModelRunner:
             logical_id = seq.block_table[block_idx]
             self.kvcache_manager.prefilled_blocks.add(logical_id)

-            # NOTE: Per-layer offloading is now done in attention.forward
-            # Each layer offloads its KV to CPU immediately after computing attention.
-            # We just need to wait for the last offload to complete before reusing the slot.
-            if block_idx < len(cpu_block_ids):
-                # TODO: Sparse policy hook needs update for new GPU cache architecture
-                # The GPU cache no longer has layer dimension, so we can't access
-                # k_cache_gpu[layer_id, write_slot]. Sparse policy should be called
-                # in attention.forward after per-layer offload.
-                pass
-
-            # Wait for offload to complete before next chunk
-            # (slot will be reused after N chunks)
-            offload_engine.wait_slot_offload(write_slot)
+            # NOTE: Per-layer async offloading is now done in attention.forward
+            # Each layer offloads from its own prefill buffer - no waiting required!
+            # The sparse policy hook is called in offload_prefill_buffer_async.

             processed_tokens = chunk_end
             chunk_idx += 1

-        # Wait for all offloads to complete
-        offload_engine.wait_all_offload_done()
+        # Wait for all async prefill offloads to complete
+        offload_engine.wait_all_prefill_offloads()

-        print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)
+        logger.debug(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks")

         # Sample from last logits
         # For chunked prefill, ParallelLMHead automatically selects last position's logits
@@ -570,14 +576,15 @@ class ModelRunner:

     def run_chunked_offload_decode(self, seqs: list[Sequence]) -> list[int]:
         """
-        Run decode with ring buffer (CPU is primary storage).
+        Run decode with cross-layer pipeline (CPU is primary storage).

         All KV is on CPU. Uses decode_slot (slot[0]) to write new KV.
-        Other slots (slots[1:]) are used to load previous KV chunks via pipeline.
-        New token's KV is written to decode_slot then offloaded to CPU only when block is full.
+        Optimized with cross-layer pipeline: Layer N's data is loaded while
+        Layer N-1 computes, achieving transfer/compute overlap.

         Key: decode_slot is dedicated to writing new KV, never used for loading.
-        Optimization: Batch offloads - only offload when block is full, attend to all accumulated tokens.
+        Optimization: Cross-layer pipeline reduces effective latency by overlapping
+        H2D transfers with attention computation across layers.
         """
         assert len(seqs) == 1, "Ring buffer decode only supports single sequence"
         seq = seqs[0]
@@ -598,6 +605,12 @@ class ModelRunner:
         # Get decode start position for accumulated token tracking
         decode_start_pos = self.kvcache_manager.get_decode_start_pos(seq)

+        # Get prefilled CPU blocks for pipeline initialization
+        cpu_block_table = self.kvcache_manager.get_prefilled_cpu_blocks(seq)
+
+        # Start cross-layer pipeline (preloads Layer 0's data)
+        offload_engine.start_decode_pipeline(cpu_block_table)
+
         # Set up context for chunked decode
         set_context(
             is_prefill=False,
@@ -614,6 +627,9 @@ class ModelRunner:
         logits = self.run_model(input_ids, positions, is_prefill=False)
         reset_context()

+        # End cross-layer pipeline
+        offload_engine.end_decode_pipeline()
+
         # Only offload when block is full (pos_in_block == block_size - 1)
         # This avoids unnecessary offloading on every decode step
         if pos_in_block == self.block_size - 1:
@@ -35,7 +35,29 @@ class Scheduler:
             if Observer.ttft_start == 0:
                 Observer.ttft_start = perf_counter_ns()
             seq = self.waiting[0]
-            if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.kvcache_manager.can_allocate(seq):
+
+            # Check if sequence is too large
+            if not self.running and num_seqs == 0:
+                # First sequence, give clear error if it can't be scheduled
+                if len(seq) > self.max_num_batched_tokens:
+                    raise RuntimeError(
+                        f"Sequence too long: {len(seq)} tokens exceeds "
+                        f"max_num_batched_tokens={self.max_num_batched_tokens}. "
+                        f"Increase max_num_batched_tokens (set equal to max_model_len for long sequences)."
+                    )
+                if not self.kvcache_manager.can_allocate(seq):
+                    blocks_needed = seq.num_blocks
+                    blocks_available = self.kvcache_manager.num_free_blocks
+                    raise RuntimeError(
+                        f"Cannot allocate KV cache for sequence: "
+                        f"need {blocks_needed} blocks ({len(seq)} tokens), "
+                        f"but only {blocks_available} blocks available. "
+                        f"Increase max_model_len to allocate more blocks."
+                    )
+
+            if num_batched_tokens + len(seq) > self.max_num_batched_tokens:
+                break
+            if not self.kvcache_manager.can_allocate(seq):
                 break
             num_seqs += 1
             self.kvcache_manager.allocate(seq)
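The new error messages point at two knobs. A hedged example of how a long-prompt run might be configured to satisfy both checks; the parameter names follow the messages above, the model path is illustrative, and whether the `LLM` constructor forwards these kwargs to nanovllm's Config is an assumption of this sketch:

```python
# Sketch only: assumes LLM(...) forwards these kwargs to nanovllm's Config;
# the model path is illustrative.
from nanovllm import LLM

llm = LLM(
    "Qwen/Qwen3-0.6B",
    enforce_eager=True,
    max_model_len=64 * 1024,            # enough KV-cache blocks for the prompt
    max_num_batched_tokens=64 * 1024,   # set equal to max_model_len for long sequences
)
```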
@@ -60,7 +82,7 @@ class Scheduler:
             num_seqs += 1
             self.kvcache_manager.may_append(seq)
             scheduled_seqs.append(seq)
-        assert scheduled_seqs
+        assert scheduled_seqs, "No sequences scheduled - this should not happen"
         self.running.extendleft(reversed(scheduled_seqs))
         return scheduled_seqs, False

@@ -12,7 +12,7 @@ class SequenceStatus(Enum):


 class Sequence:
-    block_size = 4096
+    block_size = 1024
     counter = count()

     def __init__(self, token_ids: list[int], sampling_params = SamplingParams()):
@@ -34,6 +34,14 @@ class Sequence:
     def __getitem__(self, key):
         return self.token_ids[key]

+    def __repr__(self):
+        ids = self.token_ids
+        if len(ids) > 20:
+            ids_str = "[" + ", ".join(map(str, ids[:10])) + ", ..., " + ", ".join(map(str, ids[-5:])) + "]"
+        else:
+            ids_str = str(ids)
+        return f"Seq(id={self.seq_id}, status={self.status.name}, tokens={self.num_tokens}, ids={ids_str})"
+
     @property
     def is_finished(self):
         return self.status == SequenceStatus.FINISHED
@@ -56,14 +56,26 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
     # Need CPU offload: use hybrid manager
     from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
     from nanovllm.kvcache.policies import get_policy
+    from nanovllm.kvcache.sparse import create_sparse_policy
+    from nanovllm.config import SparsePolicyType

-    policy = get_policy(getattr(config, 'offload_policy', 'lru'))
+    eviction_policy = get_policy(getattr(config, 'offload_policy', 'lru'))
+
+    # Create sparse policy from config enum
+    # Quest is decode-only: prefill returns all blocks (query=None), decode does Top-K
+    sparse_policy_type = getattr(config, 'sparse_policy', SparsePolicyType.FULL)
+    sparse_policy = create_sparse_policy(
+        sparse_policy_type,
+        topk_blocks=getattr(config, 'sparse_topk_blocks', 8),
+        threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
+    )

     return HybridKVCacheManager(
         num_gpu_slots=num_gpu_blocks,
         num_cpu_blocks=num_cpu_blocks,
         block_size=config.kvcache_block_size,
-        policy=policy,
+        policy=eviction_policy,
+        sparse_policy=sparse_policy,
     )

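A hedged sketch of how a config might opt into sparse attention through this factory. The config field names are the ones read by the `getattr` calls above; the `QUEST` enum member is an assumption (only `FULL` is visible in this hunk, with Quest mentioned only in comments):

```python
# Sketch only: SparsePolicyType.QUEST is assumed, not confirmed by this diff.
from nanovllm.config import SparsePolicyType

config.sparse_policy = SparsePolicyType.QUEST   # decode-only Top-K block selection
config.sparse_topk_blocks = 8                   # keep the 8 highest-scoring blocks
config.sparse_threshold_blocks = 4              # threshold passed to create_sparse_policy

manager = create_kvcache_manager(config)        # builds HybridKVCacheManager with the policy
```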
@@ -90,6 +90,7 @@ class HybridKVCacheManager(KVCacheManager):
         num_cpu_blocks: int,
         block_size: int,
         policy: Optional[EvictionPolicy] = None,
+        sparse_policy: "SparsePolicy" = None,
     ):
         """
         Initialize hybrid manager with CPU-primary ring buffer design.
@@ -102,6 +103,7 @@ class HybridKVCacheManager(KVCacheManager):
             num_cpu_blocks: Number of CPU pool blocks (primary storage)
             block_size: Tokens per block
             policy: Eviction policy (default: LRU, used for prefix cache management)
+            sparse_policy: Sparse attention policy (Quest for decode-only sparse)
         """
         self._block_size = block_size
         self.num_gpu_slots = num_gpu_slots
@@ -113,6 +115,9 @@ class HybridKVCacheManager(KVCacheManager):
         # Eviction policy
         self.policy = policy or LRUPolicy()

+        # Sparse attention policy (set at construction time, immutable)
+        self.sparse_policy = sparse_policy
+
         # Logical blocks (what sequences reference) - one per CPU block
         self.logical_blocks: List[LogicalBlock] = [
             LogicalBlock(i) for i in range(self.total_blocks)
@@ -128,6 +133,9 @@ class HybridKVCacheManager(KVCacheManager):
         self.cpu_block_to_logical: Dict[int, int] = {}  # cpu_block -> logical_id

         # Prefix cache (uses logical block IDs)
+        # NOTE: Currently WRITE-ONLY in offload mode - hashes are stored but never
+        #> used for cache hit detection. This is intentional: offload mode always
+        #> allocates new blocks and doesn't reuse existing ones.
         self.hash_to_logical_id: Dict[int, int] = {}

         # Step counter for policy
@@ -146,8 +154,9 @@ class HybridKVCacheManager(KVCacheManager):
         # Key: sequence id, Value: starting position where decode began in current block
         self._decode_start_pos: Dict[int, int] = {}

-        # Sparse attention policy (optional)
-        self.sparse_policy: Optional["SparsePolicy"] = None
+        # Track original prefill length (for correct last_block_valid_tokens calculation)
+        # Key: sequence id, Value: number of tokens from prefill (before decode started)
+        self._prefill_len: Dict[int, int] = {}

     @property
     def block_size(self) -> int:
@@ -173,6 +182,7 @@ class HybridKVCacheManager(KVCacheManager):
             num_kv_heads=num_kv_heads,
             head_dim=head_dim,
             dtype=dtype,
+            sparse_policy=self.sparse_policy,
         )

     def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
@@ -180,24 +190,6 @@ class HybridKVCacheManager(KVCacheManager):
         assert self.offload_engine is not None
         return self.offload_engine.get_layer_cache(layer_id)

-    def set_sparse_policy(self, policy: "SparsePolicy") -> None:
-        """
-        Set sparse attention policy for block selection.
-
-        The sparse policy determines which KV blocks to load from CPU
-        for each query chunk during chunked attention computation.
-
-        Args:
-            policy: SparsePolicy instance (e.g., VerticalSlashPolicy, QuestPolicy)
-
-        Example:
-            from nanovllm.kvcache.sparse import VerticalSlashPolicy, VerticalSlashConfig
-            policy = VerticalSlashPolicy(VerticalSlashConfig(num_sink_blocks=2))
-            manager.set_sparse_policy(policy)
-        """
-        self.sparse_policy = policy
-        logger.info(f"Sparse attention policy set: {policy}")
-
     def can_allocate(self, seq: Sequence) -> bool:
         """Check if we can allocate blocks for a new sequence."""
         return len(self.free_logical_ids) >= seq.num_blocks
@@ -254,14 +246,10 @@ class HybridKVCacheManager(KVCacheManager):
         pos_in_block = seq_len % self._block_size

         if pos_in_block == 1:
-            # Need new block
-            assert last_block.hash != -1
-
+            # Need new block (previous block is full)
             logical_id = self.free_logical_ids.popleft()
             block = self.logical_blocks[logical_id]
             block.ref_count = 1
-            block.hash = -1
-            block.token_ids = []

             # Allocate new block to CPU (ring buffer mode)
             if not self.free_cpu_blocks:
@@ -275,17 +263,13 @@ class HybridKVCacheManager(KVCacheManager):
             block_table.append(logical_id)

         elif pos_in_block == 0:
-            # Block is full, update hash for prefix cache
-            assert last_block.hash == -1
-            token_ids = seq.block(seq.num_blocks - 1)
-            prefix_hash = (
-                self.logical_blocks[block_table[-2]].hash
-                if len(block_table) > 1 else -1
-            )
-            h = self.compute_hash(token_ids, prefix_hash)
-            last_block.hash = h
-            last_block.token_ids = token_ids.copy()
-            self.hash_to_logical_id[h] = last_logical_id
+            # Block is full
+            # NOTE: Prefix cache disabled in offload mode
+            # If enabled, would compute hash and update:
+            # h = self.compute_hash(seq.block(seq.num_blocks - 1), prefix_hash)
+            # last_block.hash = h
+            # self.hash_to_logical_id[h] = last_logical_id
+            pass

     def prepare_for_attention(
         self,
@@ -365,8 +349,6 @@ class HybridKVCacheManager(KVCacheManager):
         """
         assert not seq.block_table, "Sequence already has blocks"

-        h = -1  # Running hash for prefix cache
-
         for i in range(seq.num_blocks):
             # Allocate CPU block
             if not self.free_cpu_blocks:
@@ -377,19 +359,10 @@ class HybridKVCacheManager(KVCacheManager):

             cpu_block_id = self.free_cpu_blocks.popleft()

-            # Get token IDs for this block and compute hash
-            token_ids = seq.block(i)
-            if len(token_ids) == self._block_size:
-                h = self.compute_hash(token_ids, h)
-            else:
-                h = -1  # Incomplete block
-
             # Allocate logical block
             logical_id = self.free_logical_ids.popleft()
             block = self.logical_blocks[logical_id]
             block.ref_count = 1
-            block.hash = h
-            block.token_ids = token_ids.copy() if len(token_ids) == self._block_size else []
             block.location = BlockLocation.CPU
             block.cpu_block_id = cpu_block_id
             block.gpu_slot = -1
@@ -397,9 +370,11 @@ class HybridKVCacheManager(KVCacheManager):
             self.cpu_block_to_logical[cpu_block_id] = logical_id
             seq.block_table.append(logical_id)

-            # Update prefix cache
-            if h != -1:
-                self.hash_to_logical_id[h] = logical_id
+            # NOTE: Prefix cache disabled in offload mode
+            # If enabled, would compute hash and update:
+            # h = self.compute_hash(seq.block(i), prefix_hash)
+            # block.hash = h
+            # self.hash_to_logical_id[h] = logical_id

     def get_cpu_block_table(self, seq: Sequence) -> List[int]:
         """
@@ -542,6 +517,26 @@ class HybridKVCacheManager(KVCacheManager):
         seq_id = id(seq)
         self._decode_start_pos[seq_id] = 0

+    def get_prefill_len(self, seq: Sequence) -> int:
+        """
+        Get the original prefill length for a sequence.
+
+        This is cached on first call to ensure correct last_block_valid_tokens
+        calculation during decode (the CPU blocks don't change after prefill).
+
+        Args:
+            seq: Sequence
+
+        Returns:
+            Number of tokens from prefill (before decode started)
+        """
+        seq_id = id(seq)
+        if seq_id not in self._prefill_len:
+            # First decode step - store the prefill length
+            # len(seq) - 1 because current len includes the first decode token
+            self._prefill_len[seq_id] = len(seq) - 1
+        return self._prefill_len[seq_id]
+
     def clear_decode_tracking(self, seq: Sequence) -> None:
         """
         Clear decode position tracking for sequence.
@@ -553,6 +548,7 @@ class HybridKVCacheManager(KVCacheManager):
         """
         seq_id = id(seq)
         self._decode_start_pos.pop(seq_id, None)
+        self._prefill_len.pop(seq_id, None)

     def __repr__(self) -> str:
         return (
@@ -17,6 +17,11 @@ from nanovllm.kvcache.kernels import gathered_copy_kv
 from nanovllm.comm import memcpy_2d_async
 from nanovllm.utils.logger import get_logger

+# Import for type hints only (avoid circular import)
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from nanovllm.kvcache.sparse import SparsePolicy
+
 logger = get_logger("offload_engine")


@@ -35,14 +40,13 @@ class OffloadEngine:
     High-performance CPU-GPU async transfer engine for KV cache offloading.

     Memory layout:
-    - GPU cache: [num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]
+    - GPU cache: [num_gpu_blocks, block_size, kv_heads, head_dim] (no layer dimension)
     - CPU cache: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim] (pinned)
-    - Gather indices: [num_layers, num_gpu_blocks] (fixed address, variable content)

-    CUDA Graph compatibility:
-    - gathered_h2d_layer() can be captured into CUDA graphs
-    - update_gather_indices() is called outside graphs to prepare indices
-    - All tensor addresses remain fixed across graph replays
+    Features:
+    - Unified ring buffer for chunked prefill/decode
+    - Per-layer prefill buffer for async offload
+    - Cross-layer pipeline for decode with double-buffering
     """

     def __init__(
@@ -55,6 +59,7 @@ class OffloadEngine:
         head_dim: int,
         dtype: torch.dtype = torch.float16,
         num_streams: int = 4,
+        sparse_policy: "SparsePolicy" = None,
     ):
         self.num_layers = num_layers
         self.num_gpu_blocks = num_gpu_blocks
@@ -136,6 +141,64 @@ class OffloadEngine:
         decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
         logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")

+        # ========== Cross-layer pipeline buffers for decode ==========
+        # Double-buffered layer cache for pipelined decode:
+        # - Buffer A: Current layer's prefilled KV being computed
+        # - Buffer B: Next layer's prefilled KV being loaded
+        # Shape: [max_prefill_blocks, block_size, kv_heads, head_dim]
+        # Memory: 2 * max_prefill_blocks * block_size * kv_heads * head_dim * dtype_size
+        max_prefill_blocks = num_cpu_blocks  # Can hold all prefill blocks
+        self.layer_k_buffer_a = torch.zeros(
+            max_prefill_blocks, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        self.layer_v_buffer_a = torch.zeros(
+            max_prefill_blocks, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        self.layer_k_buffer_b = torch.zeros(
+            max_prefill_blocks, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        self.layer_v_buffer_b = torch.zeros(
+            max_prefill_blocks, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        layer_buf_mb = 4 * max_prefill_blocks * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
+        logger.info(f" Cross-layer pipeline buffers: {layer_buf_mb:.1f} MB ({max_prefill_blocks} blocks × 2)")
+
+        # Pipeline state tracking
+        self._pipeline_active = False
+        self._pipeline_current_buffer = 0  # 0 = buffer A, 1 = buffer B
+        self._pipeline_next_layer_event = torch.cuda.Event()
+        self._pipeline_cpu_blocks: list = []  # CPU block IDs to load
+        self._pipeline_num_blocks = 0
+        self._pipeline_layer_stream = torch.cuda.Stream()  # Dedicated stream for layer loading
+
+        # ========== Per-layer prefill buffer for async offload ==========
+        # During chunked prefill, all layers share the same GPU slot. This means
+        # each layer must wait for offload to complete before the next layer can
+        # write to the same slot. This serializes offloads and hurts performance.
+        # Solution: Maintain separate per-layer buffers for prefill.
+        # Each layer writes to its own buffer, enabling fully async offloads.
+        # Shape: [num_layers, block_size, kv_heads, head_dim]
+        self.prefill_k_buffer = torch.zeros(
+            num_layers, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        self.prefill_v_buffer = torch.zeros(
+            num_layers, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        prefill_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
+        logger.info(f" Per-layer prefill buffer: {prefill_buf_mb:.1f} MB")
+
+        # Per-layer offload events for async prefill offload
+        # Each layer has its own event to track offload completion
+        self.prefill_offload_events = [torch.cuda.Event() for _ in range(num_layers)]
+        # Per-layer transfer streams for parallel offloads
+        self.prefill_offload_streams = [torch.cuda.Stream() for _ in range(num_layers)]
+
         # ========== Fixed-address CPU KV cache (pinned memory) ==========
         self.k_cache_cpu = torch.zeros(
             num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
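The comments in the hunk above describe a classic double-buffering (ping-pong) scheme: while attention for one layer runs on the current buffer, the next layer's KV is copied host-to-device into the other buffer on a dedicated stream. A schematic sketch of that pattern in isolation; all names, shapes, and the fake load/compute bodies are illustrative, not the engine's actual API:

```python
# Sketch only: schematic ping-pong double-buffering, not the engine's real API.
import torch

num_layers = 4
cpu_kv = torch.randn(num_layers, 8, 1024, pin_memory=True)       # pinned CPU KV per layer
bufs = [torch.empty(8, 1024, device="cuda") for _ in range(2)]   # buffer A / buffer B
ready = [torch.cuda.Event(), torch.cuda.Event()]                 # H2D copy finished
freed = [torch.cuda.Event(), torch.cuda.Event()]                 # compute on buffer finished
load_stream = torch.cuda.Stream()                                # dedicated H2D stream

def load_layer(layer_id: int, buf_idx: int) -> None:
    with torch.cuda.stream(load_stream):
        load_stream.wait_event(freed[buf_idx])                   # don't overwrite a buffer in use
        bufs[buf_idx].copy_(cpu_kv[layer_id], non_blocking=True) # async H2D copy
        ready[buf_idx].record()

def compute(layer_id: int, kv: torch.Tensor) -> None:
    kv.sum()                                                     # stand-in for attention

for e in freed:
    e.record()                                                   # both buffers start out free
current = 0
load_layer(0, current)                                           # preload layer 0

for layer_id in range(num_layers):
    nxt = 1 - current
    if layer_id + 1 < num_layers:
        load_layer(layer_id + 1, nxt)                            # overlaps with compute below
    torch.cuda.current_stream().wait_event(ready[current])
    compute(layer_id, bufs[current])                             # runs while next copy is in flight
    freed[current].record()                                      # buffer can now be reused
    current = nxt
```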
@@ -146,19 +209,6 @@ class OffloadEngine:
             dtype=dtype, device="cpu", pin_memory=True
         )

-        # ========== Fixed-address gather indices (content is variable) ==========
-        # gather_indices[layer][i] = CPU block id to copy to GPU slot i
-        # -1 means no-op (skip this slot)
-        self.gather_indices_cpu = torch.empty(
-            num_layers, num_gpu_blocks,
-            dtype=torch.int64, device="cpu", pin_memory=True
-        )
-        self.gather_indices_cpu.fill_(-1)
-        self.gather_indices_gpu = torch.full(
-            (num_layers, num_gpu_blocks), -1,
-            dtype=torch.int64, device="cuda"
-        )
-
         # Log memory allocation
         gpu_mem_mb = self.gpu_memory_bytes() / (1024 * 1024)
         cpu_mem_mb = self.cpu_memory_bytes() / (1024 * 1024)
@@ -201,7 +251,7 @@ class OffloadEngine:
         # This prevents undefined behavior on first load_to_slot_layer call
         for slot_idx in range(self.num_ring_slots):
             self.ring_slot_compute_done[slot_idx].record()
-        torch.cuda.synchronize()  # Ensure all events are recorded
+        # torch.cuda.synchronize()  # Ensure all events are recorded

         # ========== Event tracking for async transfers ==========
         self.pending_events: Dict[Tuple[int, int], torch.cuda.Event] = {}
@@ -210,320 +260,8 @@ class OffloadEngine:
         self._debug_mode = False
         self._debug_hooks: List = []  # External hooks for debug events

+        # ========== Sparse attention policy (set at construction time) ==========
+        self.sparse_policy = sparse_policy
-    def _get_next_stream(self) -> torch.cuda.Stream:
-        """Round-robin stream selection for parallel transfers."""
-        stream = self.transfer_streams[self._stream_idx]
-        self._stream_idx = (self._stream_idx + 1) % len(self.transfer_streams)
-        return stream
-
-    # ========== CUDA Graph compatible methods ==========
-    # NOTE: These methods need to be updated for the new GPU cache architecture.
-    # GPU cache no longer has layer dimension, so gathered copy semantics change.
-    # For now, these are kept for reference but should not be used without updating.
-
-    def gathered_h2d_layer(self, layer_id: int) -> None:
-        """
-        Execute gathered H2D copy for a single layer.
-
-        WARNING: This method needs updating for new GPU cache architecture.
-        GPU cache no longer has layer dimension.
-        """
-        # GPU cache has no layer dimension - use flat indexing
-        # Source is CPU[layer_id], dest is GPU (shared across layers)
-        gathered_copy_kv(
-            k_src=self.k_cache_cpu[layer_id],
-            v_src=self.v_cache_cpu[layer_id],
-            k_dst=self.k_cache_gpu,  # No layer indexing
-            v_dst=self.v_cache_gpu,  # No layer indexing
-            indices=self.gather_indices_gpu[layer_id],
-        )
-
-    def gathered_h2d_all_layers(self) -> None:
-        """
-        Execute gathered H2D copy for all layers.
-
-        WARNING: In new architecture, GPU slots are shared across layers.
-        This method would overwrite slots multiple times. Not recommended.
-        """
-        for layer_id in range(self.num_layers):
-            self.gathered_h2d_layer(layer_id)
-
-    def update_gather_indices(
-        self,
-        layer_id: int,
-        mappings: List[Tuple[int, int]],
-    ) -> None:
-        """
-        Update gather indices for a layer (call OUTSIDE CUDA graph).
-
-        Args:
-            layer_id: Layer index
-            mappings: List of (cpu_block_id, gpu_slot) tuples
-                Only these slots will be updated; others keep their values
-        """
-        for cpu_block_id, gpu_slot in mappings:
-            self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
-
-        # Async copy to GPU
-        self.gather_indices_gpu[layer_id].copy_(
-            self.gather_indices_cpu[layer_id],
-            non_blocking=True
-        )
-
-    def update_gather_indices_all_layers(
-        self,
-        mappings_per_layer: List[List[Tuple[int, int]]],
-    ) -> None:
-        """
-        Update gather indices for all layers.
-
-        Args:
-            mappings_per_layer: mappings_per_layer[layer_id] = [(cpu_block_id, gpu_slot), ...]
-        """
-        for layer_id, mappings in enumerate(mappings_per_layer):
-            for cpu_block_id, gpu_slot in mappings:
-                self.gather_indices_cpu[layer_id, gpu_slot] = cpu_block_id
-
-        # Batch copy all layers
-        self.gather_indices_gpu.copy_(self.gather_indices_cpu, non_blocking=True)
-
-    def clear_gather_indices(self, layer_id: Optional[int] = None) -> None:
-        """
-        Clear gather indices (set all to -1, meaning no-op).
-
-        Args:
-            layer_id: If provided, clear only this layer; otherwise clear all
-        """
-        if layer_id is not None:
-            self.gather_indices_cpu[layer_id].fill_(-1)
-            self.gather_indices_gpu[layer_id].fill_(-1)
-        else:
-            self.gather_indices_cpu.fill_(-1)
-            self.gather_indices_gpu.fill_(-1)
-
-    # ========== Async transfer methods (for prefill, outside CUDA graph) ==========
-
-    def prefetch_block_async(
-        self,
-        layer_id: int,
-        cpu_block_id: int,
-        gpu_block_id: int,
-    ) -> torch.cuda.Event:
-        """
-        Async prefetch a single block from CPU to GPU.
|
|
||||||
|
|
||||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer_id: Layer index (for CPU cache)
|
|
||||||
cpu_block_id: Source block in CPU cache
|
|
||||||
gpu_block_id: Destination slot in GPU cache
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CUDA event that signals completion
|
|
||||||
"""
|
|
||||||
stream = self._get_next_stream()
|
|
||||||
event = torch.cuda.Event()
|
|
||||||
|
|
||||||
logger.debug(f"H2D prefetch: layer={layer_id}, CPU[{cpu_block_id}] -> GPU[{gpu_block_id}]")
|
|
||||||
|
|
||||||
with torch.cuda.stream(stream):
|
|
||||||
# GPU: no layer dimension, CPU: has layer dimension
|
|
||||||
self.k_cache_gpu[gpu_block_id].copy_(
|
|
||||||
self.k_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
self.v_cache_gpu[gpu_block_id].copy_(
|
|
||||||
self.v_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
event.record()
|
|
||||||
|
|
||||||
self.pending_events[(layer_id, gpu_block_id)] = event
|
|
||||||
return event
|
|
||||||
|
|
||||||
def prefetch_blocks_batch_async(
|
|
||||||
self,
|
|
||||||
transfers: List[Tuple[int, int, int]], # [(layer_id, cpu_block_id, gpu_block_id), ...]
|
|
||||||
) -> List[torch.cuda.Event]:
|
|
||||||
"""
|
|
||||||
Batch async prefetch multiple blocks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
transfers: List of (layer_id, cpu_block_id, gpu_block_id) tuples
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of CUDA events for each transfer
|
|
||||||
"""
|
|
||||||
events = []
|
|
||||||
for layer_id, cpu_block_id, gpu_block_id in transfers:
|
|
||||||
event = self.prefetch_block_async(layer_id, cpu_block_id, gpu_block_id)
|
|
||||||
events.append(event)
|
|
||||||
return events
|
|
||||||
|
|
||||||
def offload_block_async(
|
|
||||||
self,
|
|
||||||
layer_id: int,
|
|
||||||
gpu_block_id: int,
|
|
||||||
cpu_block_id: int,
|
|
||||||
) -> torch.cuda.Event:
|
|
||||||
"""
|
|
||||||
Async offload a block from GPU to CPU.
|
|
||||||
|
|
||||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer_id: Layer index (for CPU cache)
|
|
||||||
gpu_block_id: Source slot in GPU cache
|
|
||||||
cpu_block_id: Destination block in CPU cache
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CUDA event that signals completion
|
|
||||||
"""
|
|
||||||
stream = self._get_next_stream()
|
|
||||||
event = torch.cuda.Event()
|
|
||||||
|
|
||||||
logger.debug(f"D2H offload: layer={layer_id}, GPU[{gpu_block_id}] -> CPU[{cpu_block_id}]")
|
|
||||||
|
|
||||||
with torch.cuda.stream(stream):
|
|
||||||
# Wait for any compute using this block
|
|
||||||
stream.wait_stream(self.compute_stream)
|
|
||||||
|
|
||||||
# GPU: no layer dimension, CPU: has layer dimension
|
|
||||||
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
|
|
||||||
self.k_cache_gpu[gpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
self.v_cache_cpu[layer_id, cpu_block_id].copy_(
|
|
||||||
self.v_cache_gpu[gpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
event.record()
|
|
||||||
|
|
||||||
return event
|
|
||||||
|
|
||||||
def offload_blocks_batch_async(
|
|
||||||
self,
|
|
||||||
transfers: List[Tuple[int, int, int]], # [(layer_id, gpu_block_id, cpu_block_id), ...]
|
|
||||||
) -> List[torch.cuda.Event]:
|
|
||||||
"""
|
|
||||||
Batch async offload multiple blocks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
transfers: List of (layer_id, gpu_block_id, cpu_block_id) tuples
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of CUDA events
|
|
||||||
"""
|
|
||||||
events = []
|
|
||||||
for layer_id, gpu_block_id, cpu_block_id in transfers:
|
|
||||||
event = self.offload_block_async(layer_id, gpu_block_id, cpu_block_id)
|
|
||||||
events.append(event)
|
|
||||||
return events
|
|
||||||
|
|
||||||
# ========== Chunked Decode: Load CPU blocks to GPU slots ==========
|
|
||||||
|
|
||||||
def load_cpu_blocks_to_gpu_slots(
|
|
||||||
self,
|
|
||||||
layer_id: int,
|
|
||||||
cpu_block_ids: List[int],
|
|
||||||
gpu_slot_ids: List[int],
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Load CPU blocks to specific GPU slots for chunked decode.
|
|
||||||
|
|
||||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer_id: Layer index (for CPU cache)
|
|
||||||
cpu_block_ids: List of CPU block IDs to load
|
|
||||||
gpu_slot_ids: List of GPU slot IDs to load into (must be same length)
|
|
||||||
"""
|
|
||||||
assert len(cpu_block_ids) == len(gpu_slot_ids)
|
|
||||||
|
|
||||||
if cpu_block_ids:
|
|
||||||
logger.debug(f"H2D chunked load: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
|
|
||||||
|
|
||||||
stream = self._get_next_stream()
|
|
||||||
|
|
||||||
with torch.cuda.stream(stream):
|
|
||||||
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
|
|
||||||
# GPU: no layer dimension, CPU: has layer dimension
|
|
||||||
self.k_cache_gpu[gpu_slot].copy_(
|
|
||||||
self.k_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
self.v_cache_gpu[gpu_slot].copy_(
|
|
||||||
self.v_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Wait for transfer to complete
|
|
||||||
stream.synchronize()
|
|
||||||
|
|
||||||
def load_cpu_blocks_to_gpu_slots_async(
|
|
||||||
self,
|
|
||||||
layer_id: int,
|
|
||||||
cpu_block_ids: List[int],
|
|
||||||
gpu_slot_ids: List[int],
|
|
||||||
) -> torch.cuda.Event:
|
|
||||||
"""
|
|
||||||
Async version: Load CPU blocks to GPU slots.
|
|
||||||
|
|
||||||
GPU cache has no layer dimension - layer_id is for CPU cache indexing.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
layer_id: Layer index (for CPU cache)
|
|
||||||
cpu_block_ids: List of CPU block IDs to load
|
|
||||||
gpu_slot_ids: List of GPU slot IDs to load into
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CUDA event to wait on
|
|
||||||
"""
|
|
||||||
assert len(cpu_block_ids) == len(gpu_slot_ids)
|
|
||||||
|
|
||||||
if cpu_block_ids:
|
|
||||||
logger.debug(f"H2D chunked load async: layer={layer_id}, CPU{cpu_block_ids} -> GPU{gpu_slot_ids}")
|
|
||||||
|
|
||||||
stream = self._get_next_stream()
|
|
||||||
event = torch.cuda.Event()
|
|
||||||
|
|
||||||
with torch.cuda.stream(stream):
|
|
||||||
for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
|
|
||||||
# GPU: no layer dimension, CPU: has layer dimension
|
|
||||||
self.k_cache_gpu[gpu_slot].copy_(
|
|
||||||
self.k_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
self.v_cache_gpu[gpu_slot].copy_(
|
|
||||||
self.v_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
event.record()
|
|
||||||
|
|
||||||
return event
|
|
||||||
|
|
||||||
# NOTE: load_cpu_blocks_to_gpu_slots_all_layers removed - GPU cache no longer has
|
|
||||||
# layer dimension. Each GPU slot holds data for ONE layer at a time.
|
|
||||||
|
|
||||||
# ========== Synchronization methods ==========
|
|
||||||
|
|
||||||
def wait_for_block(self, layer_id: int, gpu_block_id: int) -> None:
|
|
||||||
"""Wait for a specific block's transfer to complete."""
|
|
||||||
key = (layer_id, gpu_block_id)
|
|
||||||
if key in self.pending_events:
|
|
||||||
self.pending_events[key].synchronize()
|
|
||||||
del self.pending_events[key]
|
|
||||||
|
|
||||||
def wait_all_transfers(self) -> None:
|
|
||||||
"""Wait for all pending transfers to complete."""
|
|
||||||
for stream in self.transfer_streams:
|
|
||||||
stream.synchronize()
|
|
||||||
self.pending_events.clear()
|
|
||||||
|
|
||||||
def sync_indices(self) -> None:
|
|
||||||
"""Synchronize to ensure all index updates are complete."""
|
|
||||||
torch.cuda.default_stream().synchronize()
|
|
||||||
|
|
||||||
# ========== Cache access methods ==========
|
# ========== Cache access methods ==========
|
||||||
|
|
||||||
@@ -538,54 +276,22 @@ class OffloadEngine:
|
|||||||
(k_cache, v_cache) tensors
|
(k_cache, v_cache) tensors
|
||||||
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
|
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
# GPU cache is shared across all layers (no layer dimension)
|
|
||||||
return self.k_cache_gpu, self.v_cache_gpu
|
return self.k_cache_gpu, self.v_cache_gpu
|
||||||
|
|
||||||
def get_all_gpu_cache(self) -> Tuple[Tensor, Tensor]:
|
|
||||||
"""
|
|
||||||
Get full GPU K/V cache tensors.
|
|
||||||
|
|
||||||
NOTE: GPU cache has no layer dimension in the new architecture.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(k_cache, v_cache) tensors
|
|
||||||
Shape: [num_gpu_blocks, block_size, kv_heads, head_dim]
|
|
||||||
"""
|
|
||||||
return self.k_cache_gpu, self.v_cache_gpu
|
|
||||||
|
|
||||||
def get_cpu_block(
|
|
||||||
self,
|
|
||||||
layer_id: int,
|
|
||||||
cpu_block_id: int,
|
|
||||||
) -> Tuple[Tensor, Tensor]:
|
|
||||||
"""
|
|
||||||
Get a specific CPU block's K/V cache.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(k_cache, v_cache) for the block
|
|
||||||
Shape: [block_size, kv_heads, head_dim]
|
|
||||||
"""
|
|
||||||
return (
|
|
||||||
self.k_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
self.v_cache_cpu[layer_id, cpu_block_id],
|
|
||||||
)
|
|
||||||
|
|
||||||
# ========== Memory info ==========
|
# ========== Memory info ==========
|
||||||
|
|
||||||
def gpu_memory_bytes(self) -> int:
|
def gpu_memory_bytes(self) -> int:
|
||||||
"""Total GPU memory used by KV caches."""
|
"""Total GPU memory used by KV caches."""
|
||||||
return (
|
return (
|
||||||
self.k_cache_gpu.numel() * self.k_cache_gpu.element_size() +
|
self.k_cache_gpu.numel() * self.k_cache_gpu.element_size() +
|
||||||
self.v_cache_gpu.numel() * self.v_cache_gpu.element_size() +
|
self.v_cache_gpu.numel() * self.v_cache_gpu.element_size()
|
||||||
self.gather_indices_gpu.numel() * self.gather_indices_gpu.element_size()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def cpu_memory_bytes(self) -> int:
|
def cpu_memory_bytes(self) -> int:
|
||||||
"""Total CPU memory used by KV caches."""
|
"""Total CPU memory used by KV caches."""
|
||||||
return (
|
return (
|
||||||
self.k_cache_cpu.numel() * self.k_cache_cpu.element_size() +
|
self.k_cache_cpu.numel() * self.k_cache_cpu.element_size() +
|
||||||
self.v_cache_cpu.numel() * self.v_cache_cpu.element_size() +
|
self.v_cache_cpu.numel() * self.v_cache_cpu.element_size()
|
||||||
self.gather_indices_cpu.numel() * self.gather_indices_cpu.element_size()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
@@ -730,7 +436,14 @@ class OffloadEngine:
|
|||||||
"""Wait for slot offload to complete."""
|
"""Wait for slot offload to complete."""
|
||||||
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
|
self.compute_stream.wait_event(self.ring_slot_offload_done[slot_idx])
|
||||||
|
|
||||||
def offload_slot_layer_to_cpu(self, slot_idx: int, layer_id: int, cpu_block_id: int) -> None:
|
def offload_slot_layer_to_cpu(
|
||||||
|
self,
|
||||||
|
slot_idx: int,
|
||||||
|
layer_id: int,
|
||||||
|
cpu_block_id: int,
|
||||||
|
num_valid_tokens: int = -1,
|
||||||
|
is_prefill: bool = True,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Async offload a ring buffer slot to CPU for one layer.
|
Async offload a ring buffer slot to CPU for one layer.
|
||||||
|
|
||||||
@@ -741,9 +454,21 @@ class OffloadEngine:
|
|||||||
slot_idx: Source GPU slot index
|
slot_idx: Source GPU slot index
|
||||||
layer_id: Target layer in CPU cache
|
layer_id: Target layer in CPU cache
|
||||||
cpu_block_id: Target CPU block ID
|
cpu_block_id: Target CPU block ID
|
||||||
|
num_valid_tokens: Number of valid tokens in this block (-1 = use block_size)
|
||||||
|
is_prefill: True if in prefill phase, False if in decode phase
|
||||||
"""
|
"""
|
||||||
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
|
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[layer={layer_id}, block={cpu_block_id}]")
|
||||||
|
|
||||||
|
# Collect metadata BEFORE offload (while k_cache is still on GPU)
|
||||||
|
valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
|
||||||
|
k_cache = self.k_cache_gpu[slot_idx]
|
||||||
|
|
||||||
|
if self.sparse_policy is not None:
|
||||||
|
if is_prefill:
|
||||||
|
self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||||
|
else:
|
||||||
|
self.sparse_policy.on_decode_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
|
||||||
|
|
||||||
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
|
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[L{layer_id},B{cpu_block_id}]")
|
||||||
with torch.cuda.stream(self.transfer_stream_main):
|
with torch.cuda.stream(self.transfer_stream_main):
|
||||||
# Wait for both compute_stream and default stream
|
# Wait for both compute_stream and default stream
|
||||||
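The hunk above dispatches the sparse-policy hook before the D2H copy is issued, while the slot's keys are still GPU-resident. A minimal sketch of what such a hook can compute at that point, assuming a policy that records per-block key min/max; the helper name and the `torch.aminmax` reduction are illustrative, not the repository's exact implementation:

```python
import torch

def collect_block_metadata(k_cache_gpu_block: torch.Tensor, num_valid_tokens: int):
    # k_cache_gpu_block: [block_size, kv_heads, head_dim], still on the GPU.
    # Reducing here avoids reading the block back from pinned CPU memory later.
    k_valid = k_cache_gpu_block[:num_valid_tokens]       # [tokens, kv_heads, head_dim]
    key_min, key_max = torch.aminmax(k_valid, dim=0)     # each [kv_heads, head_dim]
    return key_min, key_max
```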
@@ -869,102 +594,6 @@ class OffloadEngine:
            v = v.unsqueeze(0)
        return k, v

-    # ----- Legacy compatibility methods (for decode double-buffering) -----
-    # NOTE: GPU cache has no layer dimension. Layer ID is used for CPU cache indexing only.
-
-    def load_to_compute_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
-        """
-        Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
-
-        Uses first half of decode_load_slots as 'compute' region.
-        GPU cache has no layer dimension - layer_id is for CPU cache indexing.
-        """
-        if not cpu_block_ids:
-            return
-
-        half = max(1, len(self.decode_load_slots) // 2)
-        slots = self.decode_load_slots[:half]
-        num_to_load = min(len(cpu_block_ids), len(slots))
-
-        with torch.cuda.stream(self.transfer_stream_main):
-            for i in range(num_to_load):
-                cpu_id = cpu_block_ids[i]
-                gpu_slot = slots[i]
-                # GPU: no layer dimension, CPU: has layer dimension
-                self.k_cache_gpu[gpu_slot].copy_(
-                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
-                )
-                self.v_cache_gpu[gpu_slot].copy_(
-                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
-                )
-        if num_to_load > 0:
-            self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
-
-    def wait_compute_layer(self) -> None:
-        """Legacy: Wait for 'compute' region loading."""
-        if self.decode_load_slots:
-            self.wait_slot_layer(self.decode_load_slots[0])
-
-    def load_to_prefetch_layer(self, layer_id: int, cpu_block_ids: List[int]) -> None:
-        """
-        Legacy: Load CPU blocks to decode_load_slots for decode double-buffering.
-
-        Uses second half of decode_load_slots as 'prefetch' region.
-        GPU cache has no layer dimension - layer_id is for CPU cache indexing.
-        """
-        if not cpu_block_ids:
-            return
-
-        half = max(1, len(self.decode_load_slots) // 2)
-        slots = self.decode_load_slots[half:]
-        if not slots:
-            slots = self.decode_load_slots  # Fallback if only 1-2 slots
-        num_to_load = min(len(cpu_block_ids), len(slots))
-
-        with torch.cuda.stream(self.transfer_stream_main):
-            for i in range(num_to_load):
-                cpu_id = cpu_block_ids[i]
-                gpu_slot = slots[i]
-                # GPU: no layer dimension, CPU: has layer dimension
-                self.k_cache_gpu[gpu_slot].copy_(
-                    self.k_cache_cpu[layer_id, cpu_id], non_blocking=True
-                )
-                self.v_cache_gpu[gpu_slot].copy_(
-                    self.v_cache_cpu[layer_id, cpu_id], non_blocking=True
-                )
-        if num_to_load > 0:
-            self.ring_slot_ready[slots[0]].record(self.transfer_stream_main)
-
-    def wait_prefetch_layer(self) -> None:
-        """Legacy: Wait for 'prefetch' region loading."""
-        half = max(1, len(self.decode_load_slots) // 2)
-        slots = self.decode_load_slots[half:]
-        if slots:
-            self.wait_slot_layer(slots[0])
-        elif self.decode_load_slots:
-            self.wait_slot_layer(self.decode_load_slots[0])
-
-    def get_kv_for_compute(
-        self,
-        num_blocks: int,
-    ) -> Tuple[Tensor, Tensor]:
-        """Legacy: Get KV from 'compute' region (first half of decode_load_slots)."""
-        half = max(1, len(self.decode_load_slots) // 2)
-        slots = self.decode_load_slots[:half][:num_blocks]
-        return self.get_kv_for_slots(slots)
-
-    def get_kv_for_prefetch(
-        self,
-        num_blocks: int,
-    ) -> Tuple[Tensor, Tensor]:
-        """Legacy: Get KV from 'prefetch' region (second half of decode_load_slots)."""
-        half = max(1, len(self.decode_load_slots) // 2)
-        slots = self.decode_load_slots[half:]
-        if not slots:
-            slots = self.decode_load_slots
-        slots = slots[:num_blocks]
-        return self.get_kv_for_slots(slots)

    # ========== Debug Hook Interface ==========
    #
    # Minimal generic hook system for debugging.
@@ -1036,3 +665,207 @@ class OffloadEngine:
            if e.__class__.__name__ == 'BdbQuit':
                raise
            logger.warning(f"Debug hook error: {e}")

+    # ========== Cross-layer Pipeline Methods for Decode ==========
+
+    def start_decode_pipeline(self, cpu_block_ids: List[int]) -> None:
+        """
+        Start cross-layer pipeline for decode.
+
+        Called at the beginning of a decode step to initialize the pipeline.
+        Preloads Layer 0's data into buffer A.
+
+        Args:
+            cpu_block_ids: List of CPU block IDs for prefilled blocks
+        """
+        if not cpu_block_ids:
+            self._pipeline_active = False
+            return
+
+        self._pipeline_active = True
+        self._pipeline_cpu_blocks = cpu_block_ids
+        self._pipeline_num_blocks = len(cpu_block_ids)
+        self._pipeline_current_buffer = 0
+
+        # Preload Layer 0 into buffer A
+        self._load_layer_to_buffer(0, 0)  # layer_id=0, buffer_idx=0 (A)
+
+    def get_decode_layer_kv(self, layer_id: int, num_blocks: int) -> Tuple[Tensor, Tensor]:
+        """
+        Get KV cache for a layer during decode.
+
+        If pipeline is active, returns data from the current buffer.
+        Also triggers preloading of the next layer (if not last layer).
+
+        Args:
+            layer_id: Current layer ID
+            num_blocks: Number of blocks to return
+
+        Returns:
+            (k_cache, v_cache) tensors, shape: [num_blocks, block_size, kv_heads, head_dim]
+        """
+        if not self._pipeline_active:
+            raise RuntimeError("Decode pipeline not active. Call start_decode_pipeline first.")
+
+        # Wait for current layer's data to be ready
+        self.compute_stream.wait_event(self._pipeline_next_layer_event)
+
+        # Get current buffer
+        if self._pipeline_current_buffer == 0:
+            k = self.layer_k_buffer_a[:num_blocks]
+            v = self.layer_v_buffer_a[:num_blocks]
+        else:
+            k = self.layer_k_buffer_b[:num_blocks]
+            v = self.layer_v_buffer_b[:num_blocks]
+
+        # Trigger preloading of next layer (if not last layer)
+        next_layer_id = layer_id + 1
+        if next_layer_id < self.num_layers:
+            # Use the other buffer for next layer
+            next_buffer_idx = 1 - self._pipeline_current_buffer
+            self._load_layer_to_buffer(next_layer_id, next_buffer_idx)
+            # Switch to next buffer for next layer
+            self._pipeline_current_buffer = next_buffer_idx
+
+        return k, v
+
+    def _load_layer_to_buffer(self, layer_id: int, buffer_idx: int) -> None:
+        """
+        Async load a layer's prefilled blocks to the specified buffer.
+
+        Uses sgDMA for efficient strided transfer from CPU cache.
+
+        Args:
+            layer_id: Layer index to load
+            buffer_idx: 0 for buffer A, 1 for buffer B
+        """
+        num_blocks = self._pipeline_num_blocks
+        cpu_block_ids = self._pipeline_cpu_blocks
+
+        # Select target buffer
+        if buffer_idx == 0:
+            k_buffer = self.layer_k_buffer_a
+            v_buffer = self.layer_v_buffer_a
+        else:
+            k_buffer = self.layer_k_buffer_b
+            v_buffer = self.layer_v_buffer_b
+
+        # Load all blocks for this layer using dedicated stream
+        with torch.cuda.stream(self._pipeline_layer_stream):
+            for i, cpu_block_id in enumerate(cpu_block_ids):
+                # Copy from CPU cache (has layer dimension) to GPU buffer
+                k_buffer[i].copy_(
+                    self.k_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+                v_buffer[i].copy_(
+                    self.v_cache_cpu[layer_id, cpu_block_id],
+                    non_blocking=True
+                )
+            # Record event when all transfers complete
+            self._pipeline_next_layer_event.record(self._pipeline_layer_stream)
+
+    def end_decode_pipeline(self) -> None:
+        """
+        End the cross-layer pipeline.
+
+        Called at the end of a decode step to clean up pipeline state.
+        """
+        if self._pipeline_active:
+            # Ensure all transfers complete before ending
+            self._pipeline_layer_stream.synchronize()
+            self._pipeline_active = False
+            self._pipeline_cpu_blocks = []
+            self._pipeline_num_blocks = 0
+
+    def is_pipeline_active(self) -> bool:
+        """Check if decode pipeline is currently active."""
+        return self._pipeline_active
+
+    # ========== Per-layer Prefill Buffer Methods ==========
+    # These methods enable async offload during chunked prefill by using
+    # per-layer buffers instead of shared GPU slots.
+
+    def get_prefill_buffer(self, layer_id: int) -> Tuple[Tensor, Tensor]:
+        """
+        Get prefill buffer for a layer.
+
+        Args:
+            layer_id: Layer index
+
+        Returns:
+            (k_buffer, v_buffer), shape: [block_size, kv_heads, head_dim]
+        """
+        return self.prefill_k_buffer[layer_id], self.prefill_v_buffer[layer_id]
+
+    def get_prefill_buffer_slice(
+        self,
+        layer_id: int,
+        num_tokens: int,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Get a slice of prefill buffer for attention computation.
+
+        Args:
+            layer_id: Layer index
+            num_tokens: Number of valid tokens in current chunk
+
+        Returns:
+            (k, v) with shape [1, num_tokens, kv_heads, head_dim]
+        """
+        k = self.prefill_k_buffer[layer_id, :num_tokens].unsqueeze(0)
+        v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0)
+        return k, v
+
+    def offload_prefill_buffer_async(
+        self,
+        layer_id: int,
+        cpu_block_id: int,
+        num_valid_tokens: int = -1,
+    ) -> None:
+        """
+        Async offload prefill buffer to CPU (no waiting required).
+
+        This uses per-layer streams and events to enable fully async offloads.
+        Each layer can offload independently without blocking other layers.
+
+        Args:
+            layer_id: Layer index
+            cpu_block_id: Target CPU block ID
+            num_valid_tokens: Number of valid tokens (-1 = use block_size)
+        """
+        valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
+
+        # Collect sparse policy metadata before offload
+        if self.sparse_policy is not None:
+            k_cache = self.prefill_k_buffer[layer_id]
+            self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
+
+        # Use per-layer stream for parallel offloads
+        stream = self.prefill_offload_streams[layer_id]
+
+        torch.cuda.nvtx.range_push(f"AsyncPrefillOffload: L{layer_id}->CPU[{cpu_block_id}]")
+        with torch.cuda.stream(stream):
+            # Wait for compute to finish writing to prefill buffer
+            stream.wait_stream(self.compute_stream)
+
+            # Copy from prefill buffer to CPU
+            self.k_cache_cpu[layer_id, cpu_block_id].copy_(
+                self.prefill_k_buffer[layer_id], non_blocking=True
+            )
+            self.v_cache_cpu[layer_id, cpu_block_id].copy_(
+                self.prefill_v_buffer[layer_id], non_blocking=True
+            )
+
+        # Record completion event
+        self.prefill_offload_events[layer_id].record(stream)
+        torch.cuda.nvtx.range_pop()
+
+    def wait_all_prefill_offloads(self) -> None:
+        """Wait for all prefill buffer offloads to complete."""
+        for stream in self.prefill_offload_streams:
+            stream.synchronize()
+
+    def wait_prefill_offload(self, layer_id: int) -> None:
+        """Wait for a specific layer's prefill offload to complete."""
+        self.prefill_offload_events[layer_id].synchronize()
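The cross-layer decode pipeline added above double-buffers per-layer KV loads so the H2D transfer of layer N+1 overlaps the attention of layer N. A hedged usage sketch of the intended call pattern around one decode step; the method names come from the hunk, while `offload_engine`, `layers`, and `hidden` are placeholder caller-side objects:

```python
# Illustrative caller, not repository code.
def decode_step(offload_engine, layers, hidden, cpu_block_ids, num_blocks):
    offload_engine.start_decode_pipeline(cpu_block_ids)   # preloads layer 0 into buffer A
    for layer_id, layer in enumerate(layers):
        # Returns the current buffer and kicks off the load for layer_id + 1.
        k, v = offload_engine.get_decode_layer_kv(layer_id, num_blocks)
        hidden = layer(hidden, k, v)
    offload_engine.end_decode_pipeline()                   # sync transfer stream, clear state
    return hidden
```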
@@ -5,86 +5,67 @@ Provides pluggable policies for selecting which KV blocks to load
during chunked attention with CPU offload.

Usage:
-    from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext
-    from nanovllm.kvcache.sparse import VerticalSlashPolicy, QuestPolicy
+    from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType

-    # Use built-in policy
-    policy = VerticalSlashPolicy(VerticalSlashConfig())
+    # Create policy using factory function
+    policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)

    # Or create custom policy
    class MyPolicy(SparsePolicy):
+        supports_prefill = True
+        supports_decode = True
+
        def select_blocks(self, available_blocks, ctx):
            return available_blocks[:5]  # Just first 5 blocks
"""

+from nanovllm.config import SparsePolicyType
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
-from nanovllm.kvcache.sparse.vertical_slash import VerticalSlashPolicy, VerticalSlashConfig
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
-from nanovllm.kvcache.sparse.streaming_llm import StreamingLLMPolicy, StreamingLLMConfig
-from nanovllm.kvcache.sparse.hybrid import HybridPolicy
-
-# Built-in policy registry
-BUILTIN_SPARSE_POLICIES = {
-    "full": FullAttentionPolicy,
-    "vertical_slash": VerticalSlashPolicy,
-    "streaming_llm": StreamingLLMPolicy,
-}


-def get_sparse_policy(policy_name: str, **kwargs) -> SparsePolicy:
+def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
    """
-    Get a sparse attention policy instance by name.
+    Create a sparse policy instance from an enum type.

+    The returned policy is not yet initialized. Call policy.initialize()
+    or let the framework call it during KV cache allocation.
+
    Args:
-        policy_name: Policy name ("full", "vertical_slash", "streaming_llm", "quest")
-        **kwargs: Policy-specific configuration
+        policy_type: SparsePolicyType enum value
+        **kwargs: Policy-specific configuration options

    Returns:
-        SparsePolicy instance
-    """
-    policy_name = policy_name.lower()
-
-    if policy_name == "full":
+        SparsePolicy instance (not initialized)
+
+    Example:
+        policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
+        policy.initialize(num_layers=28, num_kv_heads=8, ...)
+    """
+    if policy_type == SparsePolicyType.FULL:
        return FullAttentionPolicy()
-    elif policy_name == "vertical_slash":
-        config = VerticalSlashConfig(
-            num_sink_blocks=kwargs.get("num_sink_blocks", 1),
-            local_window_blocks=kwargs.get("local_window_blocks", 2),
+    elif policy_type == SparsePolicyType.QUEST:
+        config = QuestConfig(
+            topk_blocks=kwargs.get("topk_blocks", 8),
            threshold_blocks=kwargs.get("threshold_blocks", 4),
+            include_sink_blocks=kwargs.get("include_sink_blocks", 0),
+            include_recent_blocks=kwargs.get("include_recent_blocks", 0),
        )
-        return VerticalSlashPolicy(config)
+        return QuestPolicy(config)
-    elif policy_name == "streaming_llm":
-        config = StreamingLLMConfig(
-            num_sink_blocks=kwargs.get("num_sink_blocks", 1),
-            num_recent_blocks=kwargs.get("num_recent_blocks", 3),
-        )
-        return StreamingLLMPolicy(config)
-    elif policy_name == "quest":
-        # Quest requires metadata_manager to be passed separately
-        raise ValueError(
-            "Quest policy requires BlockMetadataManager. "
-            "Use QuestPolicy(config, metadata_manager) directly."
-        )
    else:
-        raise ValueError(
-            f"Unknown sparse policy '{policy_name}'. "
-            f"Available policies: {list(BUILTIN_SPARSE_POLICIES.keys())}"
-        )
+        raise ValueError(f"Unknown policy type: {policy_type}")


__all__ = [
    "SparsePolicy",
    "PolicyContext",
+    "SparsePolicyType",
    "FullAttentionPolicy",
-    "VerticalSlashPolicy",
-    "VerticalSlashConfig",
    "QuestPolicy",
    "QuestConfig",
    "BlockMetadataManager",
-    "StreamingLLMPolicy",
+    "create_sparse_policy",
-    "StreamingLLMConfig",
-    "HybridPolicy",
-    "get_sparse_policy",
-    "BUILTIN_SPARSE_POLICIES",
]
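A short end-to-end usage sketch of the new factory, following the docstring in the hunk above; the model dimensions (`num_layers=28`, `num_kv_heads=8`, `head_dim=128`, `num_cpu_blocks=256`) are illustrative values, not taken from the diff:

```python
import torch
from nanovllm.config import SparsePolicyType
from nanovllm.kvcache.sparse import create_sparse_policy

policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8, threshold_blocks=4)
# The framework normally calls initialize() during KV cache allocation;
# shown explicitly here with illustrative dimensions.
policy.initialize(num_layers=28, num_kv_heads=8, head_dim=128,
                  num_cpu_blocks=256, dtype=torch.bfloat16)
```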
@@ -22,6 +22,10 @@ class FullAttentionPolicy(SparsePolicy):
    - For short sequences where sparsity isn't beneficial
    """

+    # Full attention supports both prefill and decode
+    supports_prefill = True
+    supports_decode = True
+
    def select_blocks(
        self,
        available_blocks: List[int],

@@ -1,93 +0,0 @@
-"""
-Hybrid sparse attention policy.
-
-Allows using different policies for prefill vs decode phases.
-This is useful because optimal sparsity patterns often differ:
-- Prefill: fixed patterns work well (e.g., VerticalSlash)
-- Decode: query-aware selection helps (e.g., Quest)
-"""
-
-from typing import List
-import torch
-from .policy import SparsePolicy, PolicyContext
-
-
-class HybridPolicy(SparsePolicy):
-    """
-    Hybrid policy that uses different policies for prefill and decode.
-
-    Example usage:
-    ```python
-    from nanovllm.kvcache.sparse import (
-        HybridPolicy, VerticalSlashPolicy, QuestPolicy,
-        VerticalSlashConfig, QuestConfig, BlockMetadataManager
-    )
-
-    # Prefill: use fast fixed pattern
-    prefill_policy = VerticalSlashPolicy(VerticalSlashConfig(
-        num_sink_blocks=1,
-        local_window_blocks=3,
-    ))
-
-    # Decode: use query-aware selection
-    metadata = BlockMetadataManager(num_blocks, num_layers, num_heads, head_dim)
-    decode_policy = QuestPolicy(QuestConfig(topk_blocks=8), metadata)
-
-    # Combine
-    policy = HybridPolicy(prefill_policy, decode_policy)
-    ```
-    """
-
-    def __init__(
-        self,
-        prefill_policy: SparsePolicy,
-        decode_policy: SparsePolicy,
-    ):
-        """
-        Initialize hybrid policy.
-
-        Args:
-            prefill_policy: Policy to use during prefill phase
-            decode_policy: Policy to use during decode phase
-        """
-        self.prefill_policy = prefill_policy
-        self.decode_policy = decode_policy
-
-    def select_blocks(
-        self,
-        available_blocks: List[int],
-        ctx: PolicyContext,
-    ) -> List[int]:
-        """Delegate to appropriate policy based on phase."""
-        if ctx.is_prefill:
-            return self.prefill_policy.select_blocks(available_blocks, ctx)
-        else:
-            return self.decode_policy.select_blocks(available_blocks, ctx)
-
-    def on_block_offloaded(
-        self,
-        cpu_block_id: int,
-        layer_id: int,
-        k_cache: torch.Tensor,
-        num_valid_tokens: int,
-    ) -> None:
-        """Forward to both policies (both may need metadata updates)."""
-        self.prefill_policy.on_block_offloaded(
-            cpu_block_id, layer_id, k_cache, num_valid_tokens
-        )
-        self.decode_policy.on_block_offloaded(
-            cpu_block_id, layer_id, k_cache, num_valid_tokens
-        )
-
-    def reset(self) -> None:
-        """Reset both policies."""
-        self.prefill_policy.reset()
-        self.decode_policy.reset()
-
-    def __repr__(self) -> str:
-        return (
-            f"HybridPolicy(\n"
-            f"  prefill={self.prefill_policy},\n"
-            f"  decode={self.decode_policy}\n"
-            f")"
-        )

@@ -10,6 +10,9 @@ from dataclasses import dataclass
from typing import List, Optional, Any
import torch

+# Import SparsePolicyType from config to avoid circular imports
+from nanovllm.config import SparsePolicyType
+

@dataclass
class PolicyContext:
@@ -39,7 +42,7 @@ class PolicyContext:
    is_prefill: bool
    """True if in prefill phase, False if in decode phase."""

-    block_size: int = 4096
+    block_size: int = 1024
    """Number of tokens per block."""

    total_kv_len: int = 0
@@ -54,8 +57,15 @@ class SparsePolicy(ABC):
    sparse attention patterns. The policy receives context about
    the current query chunk and returns which KV blocks to load.

+    Attributes:
+        supports_prefill: Whether this policy can be used for prefill phase.
+        supports_decode: Whether this policy can be used for decode phase.
+
    Example:
        class MySparsePolicy(SparsePolicy):
+            supports_prefill = False  # decode-only policy
+            supports_decode = True
+
            def select_blocks(self, available_blocks, ctx):
                # Load first block and last 2 blocks
                if len(available_blocks) <= 3:
@@ -63,6 +73,36 @@ class SparsePolicy(ABC):
                return [available_blocks[0]] + available_blocks[-2:]
    """

+    # Compatibility flags - override in subclasses
+    supports_prefill: bool = True
+    supports_decode: bool = True
+
+    def initialize(
+        self,
+        num_layers: int,
+        num_kv_heads: int,
+        head_dim: int,
+        num_cpu_blocks: int,
+        dtype: torch.dtype,
+        device: torch.device = None,
+    ) -> None:
+        """
+        Initialize policy resources.
+
+        Called by the framework after KV cache is allocated. Override this
+        to create metadata structures (e.g., BlockMetadataManager for Quest).
+        Default implementation does nothing.
+
+        Args:
+            num_layers: Number of transformer layers
+            num_kv_heads: Number of KV attention heads
+            head_dim: Dimension per head
+            num_cpu_blocks: Number of CPU blocks allocated
+            dtype: Data type for tensors
+            device: Device for metadata storage (GPU recommended for performance)
+        """
+        pass
+
    @abstractmethod
    def select_blocks(
        self,
@@ -90,7 +130,7 @@ class SparsePolicy(ABC):
        """
        pass

-    def on_block_offloaded(
+    def on_prefill_offload(
        self,
        cpu_block_id: int,
        layer_id: int,
@@ -98,15 +138,38 @@ class SparsePolicy(ABC):
        num_valid_tokens: int,
    ) -> None:
        """
-        Hook called when a block is offloaded from GPU to CPU.
+        Hook called when a block is offloaded during prefill phase.

+        Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
        Override this to collect metadata about blocks (e.g., min/max keys
        for Quest-style selection). Default implementation does nothing.

        Args:
-            cpu_block_id: The CPU block ID that was written
+            cpu_block_id: The CPU block ID that will be written
            layer_id: Transformer layer index
-            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
+            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
+            num_valid_tokens: Number of valid tokens in this block
+        """
+        pass
+
+    def on_decode_offload(
+        self,
+        cpu_block_id: int,
+        layer_id: int,
+        k_cache: torch.Tensor,
+        num_valid_tokens: int,
+    ) -> None:
+        """
+        Hook called when a block is offloaded during decode phase.
+
+        Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
+        Override this to update metadata about blocks. Default implementation
+        does nothing.
+
+        Args:
+            cpu_block_id: The CPU block ID that will be written
+            layer_id: Transformer layer index
+            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
            num_valid_tokens: Number of valid tokens in this block
        """
        pass
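A minimal custom policy against the updated base-class surface shown above (support flags, `select_blocks`, and the split offload hooks). The class and its selection rule are illustrative only, assuming the `nanovllm.kvcache.sparse` exports from this compare:

```python
from typing import List
import torch
from nanovllm.kvcache.sparse import SparsePolicy, PolicyContext

class LastKPolicy(SparsePolicy):
    """Illustrative decode-only policy: keep the first block and the K most recent."""
    supports_prefill = False
    supports_decode = True

    def __init__(self, keep_recent: int = 4):
        self.keep_recent = keep_recent

    def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
        if len(available_blocks) <= self.keep_recent + 1:
            return available_blocks
        return [available_blocks[0]] + available_blocks[-self.keep_recent:]

    def on_decode_offload(self, cpu_block_id: int, layer_id: int,
                          k_cache: torch.Tensor, num_valid_tokens: int) -> None:
        # This rule needs no per-block metadata, so the hook is a no-op.
        pass
```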
@@ -35,6 +35,7 @@ class BlockMetadataManager:
        num_kv_heads: int,
        head_dim: int,
        dtype: torch.dtype = torch.bfloat16,
+        device: torch.device = None,
    ):
        """
        Initialize metadata storage.
@@ -45,20 +46,23 @@ class BlockMetadataManager:
            num_kv_heads: Number of KV attention heads
            head_dim: Dimension per head
            dtype: Data type for metadata storage
+            device: Device for metadata storage (default: CUDA if available)
        """
        self.num_blocks = num_blocks
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.dtype = dtype
+        self.device = device if device is not None else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Per-block min/max key values: [num_blocks, num_layers, num_heads, head_dim]
+        # Stored on GPU for efficient score computation during decode
        shape = (num_blocks, num_layers, num_kv_heads, head_dim)
-        self.key_min = torch.zeros(shape, dtype=dtype, pin_memory=True)
-        self.key_max = torch.zeros(shape, dtype=dtype, pin_memory=True)
+        self.key_min = torch.zeros(shape, dtype=dtype, device=self.device)
+        self.key_max = torch.zeros(shape, dtype=dtype, device=self.device)

        # Track which blocks have valid metadata
-        self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool)
+        self.valid_blocks = torch.zeros(num_blocks, dtype=torch.bool, device=self.device)

    def update_metadata(
        self,
@@ -70,21 +74,21 @@ class BlockMetadataManager:
        """
        Update min/max key bounds for a block.

-        Called when a block is offloaded to CPU.
+        Called BEFORE offload to CPU, while k_cache is still on GPU.

        Args:
            block_id: CPU block ID
            layer_id: Layer index
-            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim]
+            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
            num_valid_tokens: Number of valid tokens in this block
        """
        if num_valid_tokens == 0:
            return

-        # Get valid keys only
-        k_valid = k_cache[:num_valid_tokens].cpu()  # [num_tokens, heads, dim]
+        # Get valid keys only (k_cache is on GPU, metadata is on GPU)
+        k_valid = k_cache[:num_valid_tokens]  # [num_tokens, heads, dim]

-        # Compute min/max across token dimension
+        # Compute min/max across token dimension (all on GPU)
        self.key_min[block_id, layer_id] = k_valid.min(dim=0).values
        self.key_max[block_id, layer_id] = k_valid.max(dim=0).values
        self.valid_blocks[block_id] = True
@@ -147,22 +151,42 @@ class QuestPolicy(SparsePolicy):
    This upper bound is derived from the fact that for any key k in
    the block: min_k <= k <= max_k (element-wise), so the actual
    attention score is bounded by the maximum of the two extremes.

+    Note: This is a decode-only policy. For prefill, use FullAttentionPolicy.
    """

-    def __init__(
-        self,
-        config: QuestConfig,
-        metadata_manager: BlockMetadataManager,
-    ):
+    # Quest is decode-only
+    supports_prefill = False
+    supports_decode = True
+
+    def __init__(self, config: QuestConfig):
        """
        Initialize Quest policy.

        Args:
            config: QuestConfig with selection parameters
-            metadata_manager: BlockMetadataManager for min/max key storage
        """
        self.config = config
-        self.metadata = metadata_manager
+        self.metadata: Optional[BlockMetadataManager] = None
+
+    def initialize(
+        self,
+        num_layers: int,
+        num_kv_heads: int,
+        head_dim: int,
+        num_cpu_blocks: int,
+        dtype: torch.dtype,
+        device: torch.device = None,
+    ) -> None:
+        """Create BlockMetadataManager for storing min/max keys on GPU."""
+        self.metadata = BlockMetadataManager(
+            num_blocks=num_cpu_blocks,
+            num_layers=num_layers,
+            num_kv_heads=num_kv_heads,
+            head_dim=head_dim,
+            dtype=dtype,
+            device=device,
+        )

    def select_blocks(
        self,
@@ -175,6 +199,12 @@ class QuestPolicy(SparsePolicy):
        If query is not available (some prefill scenarios), falls back
        to loading all blocks.
        """
+        if self.metadata is None:
+            raise RuntimeError(
+                "QuestPolicy not initialized. Call initialize() first or "
+                "let the framework call it during KV cache allocation."
+            )
+
        n = len(available_blocks)

        # If below threshold or no query, load all
@@ -185,15 +215,13 @@ class QuestPolicy(SparsePolicy):
            # No query available - cannot compute scores
            return available_blocks

-        # Get metadata for available blocks
+        # Get metadata for available blocks (already on GPU)
        key_min, key_max = self.metadata.get_block_metadata(
            available_blocks, ctx.layer_id
        )

-        # Move to query device for computation
+        # Metadata is already on GPU, same device as query
        device = ctx.query.device
-        key_min = key_min.to(device, non_blocking=True)
-        key_max = key_max.to(device, non_blocking=True)

        # Compute upper bound scores
        # query shape: [1, num_heads, head_dim] or [1, seq_len, num_heads, head_dim]
@@ -261,19 +289,32 @@ class QuestPolicy(SparsePolicy):

        return result

-    def on_block_offloaded(
+    def on_prefill_offload(
        self,
        cpu_block_id: int,
        layer_id: int,
        k_cache: torch.Tensor,
        num_valid_tokens: int,
    ) -> None:
-        """Update min/max key metadata when block is offloaded."""
-        self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
+        """Update min/max key metadata during prefill offload."""
+        if self.metadata is not None:
+            self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)
+
+    def on_decode_offload(
+        self,
+        cpu_block_id: int,
+        layer_id: int,
+        k_cache: torch.Tensor,
+        num_valid_tokens: int,
+    ) -> None:
+        """Update min/max key metadata during decode offload (for new blocks)."""
+        if self.metadata is not None:
+            self.metadata.update_metadata(cpu_block_id, layer_id, k_cache, num_valid_tokens)

    def reset(self) -> None:
        """Reset metadata."""
-        self.metadata.reset()
+        if self.metadata is not None:
+            self.metadata.reset()

    def __repr__(self) -> str:
        return (
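The Quest upper bound described in the docstring above can be sketched as follows. Since min_k <= k <= max_k element-wise, each coordinate's contribution q_d * k_d is at most max(q_d * min_d, q_d * max_d). The shapes and the per-block reduction here are assumptions for illustration; the repository's actual scoring code may organize the computation differently:

```python
import torch

def quest_block_score_upper_bound(query: torch.Tensor,
                                  key_min: torch.Tensor,
                                  key_max: torch.Tensor) -> torch.Tensor:
    # query:   [num_heads, head_dim]              (current decode query)
    # key_min: [num_blocks, num_heads, head_dim]  (per-block elementwise key minimum)
    # key_max: [num_blocks, num_heads, head_dim]  (per-block elementwise key maximum)
    q = query.unsqueeze(0)                               # [1, heads, dim]
    upper = torch.maximum(q * key_min, q * key_max)      # [blocks, heads, dim]
    return upper.sum(dim=-1).amax(dim=-1)                # [blocks]: best head's bound per block

# Illustrative top-k selection over the bounds:
# topk_idx = torch.topk(scores, k=min(topk_blocks, scores.numel())).indices
```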
@@ -1,84 +0,0 @@
|
|||||||
"""
|
|
||||||
StreamingLLM sparse attention policy.
|
|
||||||
|
|
||||||
Only keeps sink tokens (beginning) + recent tokens (end).
|
|
||||||
Intermediate context is discarded. This enables infinite-length
|
|
||||||
generation but loses intermediate context.
|
|
||||||
|
|
||||||
Reference: StreamingLLM paper on attention sinks.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import List
|
|
||||||
from .policy import SparsePolicy, PolicyContext
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StreamingLLMConfig:
|
|
||||||
"""Configuration for StreamingLLMPolicy."""
|
|
||||||
|
|
||||||
num_sink_blocks: int = 1
|
|
||||||
"""Number of blocks at the beginning to always include (attention sinks)."""
|
|
||||||
|
|
||||||
num_recent_blocks: int = 3
|
|
||||||
"""Number of most recent blocks to include (sliding window)."""
|
|
||||||
|
|
||||||
|
|
||||||
class StreamingLLMPolicy(SparsePolicy):
|
|
||||||
"""
|
|
||||||
StreamingLLM pattern: sink tokens + recent tokens only.
|
|
||||||
|
|
||||||
This is the most aggressive sparsity pattern - only keeps a small
|
|
||||||
fixed window of context. Suitable for:
|
|
||||||
- Very long streaming generation
|
|
||||||
- When intermediate context can be safely discarded
|
|
||||||
- Maximizing throughput over accuracy
|
|
||||||
|
|
||||||
Pattern visualization:
|
|
||||||
```
|
|
||||||
Blocks: [0] [1] [2] [3] [4] [5] [6] [7] [8]
|
|
||||||
↑ × × × ↑ ↑ ↑
|
|
||||||
sink (discarded) recent window
|
|
||||||
```
|
|
||||||
|
|
||||||
Warning: This loses information from intermediate blocks!
|
|
||||||
Use only when this trade-off is acceptable.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, config: StreamingLLMConfig = None):
|
|
||||||
self.config = config or StreamingLLMConfig()
|
|
||||||
|
|
||||||
def select_blocks(
|
|
||||||
self,
|
|
||||||
available_blocks: List[int],
|
|
||||||
ctx: PolicyContext,
|
|
||||||
) -> List[int]:
|
|
||||||
"""
|
|
||||||
Select sink blocks + recent blocks only.
|
|
||||||
|
|
||||||
Intermediate blocks are not loaded (effectively discarded).
|
|
||||||
"""
|
|
||||||
n = len(available_blocks)
|
|
||||||
|
|
||||||
# If total blocks fit in sink + recent, load all
|
|
||||||
total_keep = self.config.num_sink_blocks + self.config.num_recent_blocks
|
|
||||||
if n <= total_keep:
|
|
||||||
return available_blocks
|
|
||||||
|
|
||||||
selected_indices = set()
|
|
||||||
|
|
||||||
# Sink blocks (first N)
|
|
||||||
for i in range(min(self.config.num_sink_blocks, n)):
|
|
||||||
selected_indices.add(i)
|
|
||||||
|
|
||||||
# Recent blocks (last M)
|
|
||||||
for i in range(max(0, n - self.config.num_recent_blocks), n):
|
|
||||||
selected_indices.add(i)
|
|
||||||
|
|
||||||
return [available_blocks[i] for i in sorted(selected_indices)]
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return (
|
|
||||||
f"StreamingLLMPolicy(sink={self.config.num_sink_blocks}, "
|
|
||||||
f"recent={self.config.num_recent_blocks})"
|
|
||||||
)
|
|
||||||
@@ -1,95 +0,0 @@
"""
Vertical-Slash sparse attention policy (MInference-style).

Selects sink blocks (beginning of sequence) + local window blocks
(near the current query position). This pattern captures:
- Important initial context (system prompt, instructions)
- Recent context (relevant for local dependencies)
"""

from dataclasses import dataclass
from typing import List

from .policy import SparsePolicy, PolicyContext


@dataclass
class VerticalSlashConfig:
    """Configuration for VerticalSlashPolicy."""

    num_sink_blocks: int = 1
    """Number of blocks at the beginning to always include (sink tokens)."""

    local_window_blocks: int = 2
    """Number of blocks in the local window near current query position."""

    threshold_blocks: int = 4
    """If total blocks <= threshold, load all (no sparsity applied)."""


class VerticalSlashPolicy(SparsePolicy):
    """
    Vertical-Slash pattern: sink tokens + local window.

    This pattern is inspired by MInference and observations that:
    1. Initial tokens (sink) often receive high attention
    2. Local context (recent tokens) is important for dependencies

    Pattern visualization:
    ```
    Blocks:  [0] [1] [2] [3] [4] [5] [6] [7] [8]
              ↑                           ↑   ↑
            sink            local window (for query at block 9)
    ```

    For prefill chunk K, the local window is blocks [K-window, K-1].
    For decode, the local window is the last N blocks.
    """

    def __init__(self, config: VerticalSlashConfig = None):
        self.config = config or VerticalSlashConfig()

    def select_blocks(
        self,
        available_blocks: List[int],
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select sink blocks + local window blocks.

        For prefill: local window is relative to current chunk position.
        For decode: local window is the most recent blocks.
        """
        n = len(available_blocks)

        # If below threshold, load all
        if n <= self.config.threshold_blocks:
            return available_blocks

        selected_indices = set()

        # Sink blocks (first N blocks)
        for i in range(min(self.config.num_sink_blocks, n)):
            selected_indices.add(i)

        # Local window
        if ctx.is_prefill:
            # For prefill chunk K, local window is blocks [K-window, K-1]
            # (blocks before current chunk, not including current)
            window_end = min(ctx.query_chunk_idx, n)
            window_start = max(0, window_end - self.config.local_window_blocks)
            for i in range(window_start, window_end):
                selected_indices.add(i)
        else:
            # For decode, local window is the last M blocks
            for i in range(max(0, n - self.config.local_window_blocks), n):
                selected_indices.add(i)

        # Return blocks in order (maintains sequential access pattern)
        return [available_blocks[i] for i in sorted(selected_indices)]

    def __repr__(self) -> str:
        return (
            f"VerticalSlashPolicy(sink={self.config.num_sink_blocks}, "
            f"window={self.config.local_window_blocks}, "
            f"threshold={self.config.threshold_blocks})"
        )
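A quick sketch of the removed vertical-slash selection under its defaults (sink=1, window=2, threshold=4); `SimpleNamespace` stands in for `PolicyContext` here since only `is_prefill` and `query_chunk_idx` are read by this policy:

```python
from types import SimpleNamespace

policy = VerticalSlashPolicy()
blocks = list(range(200, 209))                            # 9 hypothetical CPU block ids
prefill_ctx = SimpleNamespace(is_prefill=True, query_chunk_idx=6)
decode_ctx = SimpleNamespace(is_prefill=False, query_chunk_idx=0)
print(policy.select_blocks(blocks, prefill_ctx))          # [200, 204, 205] -> sink + blocks 4,5
print(policy.select_blocks(blocks, decode_ctx))           # [200, 207, 208] -> sink + last 2 blocks
```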
@@ -2,8 +2,6 @@ import logging
 import torch
 import torch.cuda.nvtx
 from torch import nn
-import triton
-import triton.language as tl
 
 from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
 from nanovllm.utils.context import get_context
@@ -12,37 +10,59 @@ from nanovllm.kvcache.sparse.policy import PolicyContext
 logger = logging.getLogger(__name__)
 
 
-@triton.jit
-def store_kvcache_kernel(
-    key_ptr,
-    key_stride,
-    value_ptr,
-    value_stride,
-    k_cache_ptr,
-    v_cache_ptr,
-    slot_mapping_ptr,
-    D: tl.constexpr,
-):
-    idx = tl.program_id(0)
-    slot = tl.load(slot_mapping_ptr + idx)
-    if slot == -1: return
-    key_offsets = idx * key_stride + tl.arange(0, D)
-    value_offsets = idx * value_stride + tl.arange(0, D)
-    key = tl.load(key_ptr + key_offsets)
-    value = tl.load(value_ptr + value_offsets)
-    cache_offsets = slot * D + tl.arange(0, D)
-    tl.store(k_cache_ptr + cache_offsets, key)
-    tl.store(v_cache_ptr + cache_offsets, value)
-
-
-def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
-    N, num_heads, head_dim = key.shape
-    D = num_heads * head_dim
-    assert key.stride(-1) == 1 and value.stride(-1) == 1
-    assert key.stride(1) == head_dim and value.stride(1) == head_dim
-    assert k_cache.stride(1) == D and v_cache.stride(1) == D
-    assert slot_mapping.numel() == N
-    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
+def store_kvcache(
+    key: torch.Tensor,
+    value: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    slot_mapping: torch.Tensor,
+):
+    """
+    Store key/value tensors into KV cache using slot mapping.
+
+    This is a pure PyTorch implementation replacing the previous Triton kernel.
+    Uses index_copy_ for efficient in-place scatter operation.
+
+    Args:
+        key: [N, num_kv_heads, head_dim]
+        value: [N, num_kv_heads, head_dim]
+        k_cache: [num_blocks, block_size, num_kv_heads, head_dim] or similar
+        v_cache: same shape as k_cache
+        slot_mapping: [N] with values as flat indices, -1 means skip
+    """
+    is_capturing = torch.cuda.is_current_stream_capturing()
+
+    if is_capturing:
+        # During CUDA graph capture, assume all slots are valid.
+        # CUDA graphs don't support data-dependent operations like boolean indexing.
+        # This is safe because decode (captured) always has valid slots.
+        valid_slots = slot_mapping
+        valid_keys = key
+        valid_values = value
+    else:
+        # Normal execution: filter out invalid slots (slot == -1)
+        valid_mask = slot_mapping >= 0
+        if not valid_mask.any():
+            return
+        valid_slots = slot_mapping[valid_mask]
+        valid_keys = key[valid_mask]  # [M, num_kv_heads, head_dim]
+        valid_values = value[valid_mask]
+
+    # Flatten cache and KV for scatter operation
+    # Cache is viewed as [total_slots, D] where D = num_kv_heads * head_dim
+    N, num_kv_heads, head_dim = key.shape
+    D = num_kv_heads * head_dim
+    total_slots = k_cache.numel() // D
+
+    k_cache_flat = k_cache.view(total_slots, D)
+    v_cache_flat = v_cache.view(total_slots, D)
+    valid_keys_flat = valid_keys.reshape(-1, D)
+    valid_values_flat = valid_values.reshape(-1, D)
+
+    # In-place scatter using index_copy_
+    # index_copy_ is safe even when valid_slots is an empty tensor (it does not modify any data).
+    k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat)
+    v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat)
 
 
 class Attention(nn.Module):
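A minimal sanity check of the new `store_kvcache` on toy shapes (assumes a CUDA device; the shapes and slot values are made up for illustration):

```python
key = torch.randn(4, 2, 8, device="cuda")                  # 4 tokens, 2 KV heads, head_dim 8
value = torch.randn(4, 2, 8, device="cuda")
k_cache = torch.zeros(4, 4, 2, 8, device="cuda")           # 4 blocks * 4 slots = 16 flat slots
v_cache = torch.zeros_like(k_cache)
slot_mapping = torch.tensor([0, 1, -1, 5], device="cuda")  # token 2 is skipped (-1)

store_kvcache(key, value, k_cache, v_cache, slot_mapping)
assert torch.allclose(k_cache.view(16, 16)[5], key[3].reshape(-1))  # token 3 landed in slot 5
```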
@@ -66,8 +86,49 @@ class Attention(nn.Module):
     def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
         context = get_context()
         k_cache, v_cache = self.k_cache, self.v_cache
-        if k_cache.numel() and v_cache.numel():
-            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+        # Determine if we're in chunked offload mode
+        is_chunked_offload = (
+            context.is_chunked_prefill and
+            hasattr(context, 'kvcache_manager') and
+            context.kvcache_manager is not None and
+            hasattr(context.kvcache_manager, 'offload_engine')
+        )
+
+        #! Ensure synchronization before accessing k_cache/v_cache
+        # torch.cuda.synchronize()
+        #! =======================================================
+
+        if is_chunked_offload and context.is_prefill:
+            # Chunked prefill mode: write KV to per-layer prefill buffer (not GPU slot)
+            # This enables fully async offloads since each layer has its own buffer.
+            offload_engine = context.kvcache_manager.offload_engine
+            compute_stream = offload_engine.compute_stream
+
+            # Wait for default stream to ensure slot_mapping tensor transfer is complete
+            compute_stream.wait_stream(torch.cuda.default_stream())
+
+            with torch.cuda.stream(compute_stream):
+                # Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
+                # k, v shape: [num_tokens, kv_heads, head_dim]
+                num_tokens = k.shape[0]
+                offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
+                offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
+        elif is_chunked_offload:
+            # Chunked decode mode: use compute_stream for store_kvcache
+            # This ensures proper synchronization with per-layer offload
+            compute_stream = context.kvcache_manager.offload_engine.compute_stream
+            if k_cache.numel() and v_cache.numel():
+                # CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
+                # slot_mapping is created with non_blocking=True on default stream, but we use it
+                # on compute_stream. Without this sync, index_copy_ can get corrupted indices.
+                compute_stream.wait_stream(torch.cuda.default_stream())
+                with torch.cuda.stream(compute_stream):
+                    store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+        else:
+            # Normal mode: store on default stream
+            if k_cache.numel() and v_cache.numel():
+                store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
+
         if context.is_prefill:
             if context.is_chunked_prefill:
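The branches above all follow the same producer/consumer rule: work produced on the default stream must be waited on before a side stream consumes it, and the default stream waits back before reading the result. A generic sketch of that pattern (plain tensors, not the engine's buffers):

```python
side = torch.cuda.Stream()
buf = torch.zeros(8, 16, device="cuda")
data = torch.randn(4, 16, device="cuda")
idx = torch.arange(4, device="cuda")                 # produced on the default stream

side.wait_stream(torch.cuda.default_stream())        # consumer waits for the producer
with torch.cuda.stream(side):
    buf.index_copy_(0, idx, data)                    # runs on `side`, sees a valid `idx`
torch.cuda.default_stream().wait_stream(side)        # default-stream readers wait back
```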
@@ -111,43 +172,44 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute attention with unified ring buffer for chunked prefill.
+        Compute attention with per-layer prefill buffer for async offload.
 
-        Ring buffer design:
-        - Current chunk's KV is written to ring_slot[chunk_idx % N]
-        - Previous chunks' KV are loaded from CPU using N-1 available slots
-        - Pipeline: pre-fill slots, then process with overlapped load/compute
+        Optimized design:
+        - Current chunk's KV is written to per-layer prefill buffer (not GPU slot)
+        - Previous chunks' KV are loaded from CPU using GPU slots
+        - Each layer offloads from its own buffer - no waiting required!
 
         For each layer:
-        1. Current chunk's KV is in k_batched, v_batched (just written by model)
+        1. Current chunk's KV is in prefill_buffer[layer_id] (just written by model)
         2. Load previous chunks from CPU using available slots (pipeline)
         3. Compute attention against previous KV (no causal mask)
-        4. Compute attention against current KV (causal)
+        4. Compute attention against current KV from prefill buffer (causal)
         5. Merge all results using online softmax
+        6. Async offload prefill buffer to CPU (no waiting!)
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
         current_chunk_idx = context.current_chunk_idx
         torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
 
-        # q, k, v shape: [total_tokens, num_heads, head_dim]
-        # Reshape for flash attention: [batch, seq, heads, dim]
+        # q shape: [total_tokens, num_heads, head_dim]
         q_batched = q.unsqueeze(0)  # [1, total_tokens, heads, dim]
-        k_batched = k.unsqueeze(0)
-        v_batched = v.unsqueeze(0)
+        num_tokens = k.shape[0]
 
         o_acc = None
         lse_acc = None
 
         kvcache_manager = context.kvcache_manager
         seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
+        offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None
 
         if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
             # Get prefilled CPU blocks (blocks from previous chunks)
             cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
 
-            # Apply sparse policy if enabled
-            if cpu_block_table and kvcache_manager.sparse_policy is not None:
+            # Apply sparse policy if enabled (Quest returns all blocks for prefill since query=None)
+            sparse_policy = kvcache_manager.sparse_policy
+            if cpu_block_table and sparse_policy is not None:
                 num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
                 policy_ctx = PolicyContext(
                     query_chunk_idx=current_chunk_idx,
@@ -158,16 +220,13 @@ class Attention(nn.Module):
                     block_size=kvcache_manager.block_size,
                     total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
                 )
-                cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
+                cpu_block_table = sparse_policy.select_blocks(
                     cpu_block_table, policy_ctx
                 )
 
             if cpu_block_table:
-                offload_engine = kvcache_manager.offload_engine
-
-                # Get write slot for current chunk and available load slots
-                write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
-                load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
+                # Get available load slots (all slots can be used since we use prefill buffer)
+                load_slots = list(range(offload_engine.num_ring_slots))
                 pipeline_depth = len(load_slots)
 
                 if pipeline_depth == 0:
@@ -182,45 +241,67 @@ class Attention(nn.Module):
                     current_chunk_idx
                 )
 
-        # Compute attention against current chunk's KV (with causal mask)
-        torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
-        current_o, current_lse = flash_attn_with_lse(
-            q_batched,
-            k_batched,
-            v_batched,
-            softmax_scale=self.scale,
-            causal=True,
-        )
-        torch.cuda.nvtx.range_pop()
+        # Get compute stream for all attention operations
+        compute_stream = offload_engine.compute_stream if offload_engine is not None else None
+
+        # Compute attention against current chunk's KV from prefill buffer (with causal mask)
+        if compute_stream is not None:
+            with torch.cuda.stream(compute_stream):
+                torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
+                # Get KV from per-layer prefill buffer
+                k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
+                current_o, current_lse = flash_attn_with_lse(
+                    q_batched,
+                    k_batched,
+                    v_batched,
+                    softmax_scale=self.scale,
+                    causal=True,
+                )
+                torch.cuda.nvtx.range_pop()
+        else:
+            torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
+            k_batched = k.unsqueeze(0)
+            v_batched = v.unsqueeze(0)
+            current_o, current_lse = flash_attn_with_lse(
+                q_batched,
+                k_batched,
+                v_batched,
+                softmax_scale=self.scale,
+                causal=True,
+            )
+            torch.cuda.nvtx.range_pop()
 
-        # Merge with accumulated
+        # Merge with accumulated (all on compute_stream for consistency)
         if o_acc is None:
             final_o = current_o
         else:
-            # IMPORTANT: o_acc was computed on compute_stream. We need to sync before
-            # reading it on the default stream for the merge operation.
-            if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
-                offload_engine = kvcache_manager.offload_engine
-                torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)
-            torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
-            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
-            torch.cuda.nvtx.range_pop()
+            if compute_stream is not None:
+                with torch.cuda.stream(compute_stream):
+                    torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
+                    final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
+                    torch.cuda.nvtx.range_pop()
+            else:
+                torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
+                final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
+                torch.cuda.nvtx.range_pop()
 
         torch.cuda.nvtx.range_pop()  # ChunkedPrefill
 
-        # Per-layer offload: In new GPU cache architecture (no layer dimension),
-        # each layer must offload its KV to CPU before next layer overwrites the GPU slot.
-        if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
-            offload_engine = kvcache_manager.offload_engine
-            write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
-            seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
-            if seq is not None:
-                cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
-                if current_chunk_idx < len(cpu_block_ids):
-                    cpu_block_id = cpu_block_ids[current_chunk_idx]
-                    offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
+        # Per-layer ASYNC offload: offload prefill buffer to CPU
+        # No waiting required! Each layer has its own buffer and stream.
+        if offload_engine is not None and seq is not None:
+            cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
+            if current_chunk_idx < len(cpu_block_ids):
+                cpu_block_id = cpu_block_ids[current_chunk_idx]
+                # Async offload - no waiting, fully parallel across layers
+                offload_engine.offload_prefill_buffer_async(
+                    self.layer_id, cpu_block_id, num_tokens
+                )
+
+        # Sync default stream with compute_stream before returning
+        # This ensures the result is ready for the rest of the model (layernorm, MLP)
+        if compute_stream is not None:
+            torch.cuda.default_stream().wait_stream(compute_stream)
 
         # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
         return final_o.squeeze(0)
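The chunked path relies on `merge_attention_outputs` to combine per-chunk partial results. Its internals are not shown in this diff; the sketch below is the standard log-sum-exp merge it presumably implements, assuming the flash-attn LSE layout `[batch, heads, seq]`:

```python
def merge_partial_attention(o_a, lse_a, o_b, lse_b):
    # o_*: [1, T, H, D] partial outputs; lse_*: [1, H, T] per-query log-sum-exp values.
    lse = torch.logaddexp(lse_a, lse_b)                          # combined normaliser
    w_a = torch.exp(lse_a - lse).transpose(1, 2).unsqueeze(-1)   # [1, T, H, 1]
    w_b = torch.exp(lse_b - lse).transpose(1, 2).unsqueeze(-1)
    return o_a * w_a + o_b * w_b, lse
```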
@@ -318,6 +399,7 @@ class Attention(nn.Module):
                 offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
 
             prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
+
             prev_o, prev_lse = flash_attn_with_lse(
                 q_batched, prev_k, prev_v,
                 softmax_scale=self.scale,
@@ -364,6 +446,7 @@ class Attention(nn.Module):
 
             torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
             prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
+
             prev_o, prev_lse = flash_attn_with_lse(
                 q_batched, prev_k, prev_v,
                 softmax_scale=self.scale,
@@ -399,17 +482,15 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute decode attention using ring buffer pipeline (same as prefill).
+        Compute decode attention using cross-layer pipeline.
 
-        Uses the same loading mechanism as _chunked_prefill_attention:
-        - Load one block at a time from CPU to GPU slot
-        - Compute attention for each block
-        - Merge results using online softmax
-        - Finally merge with decode buffer (accumulated decode tokens)
-
-        This approach is simpler and proven correct (prefill tests pass).
-        The only difference from prefill is the additional decode buffer
-        that stores new tokens generated during decode.
+        Optimization: Uses double-buffered layer cache to overlap H2D transfer
+        with computation across layers:
+        - Layer N computes while Layer N+1's data is being loaded
+        - Each layer only waits for its own data, not all layers' data
+
+        This reduces effective latency from O(num_layers * transfer_time) to
+        O(transfer_time + num_layers * compute_time) when transfer < compute.
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
 
@@ -426,16 +507,19 @@ class Attention(nn.Module):
         if not cpu_block_table:
             raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
 
-        # Calculate valid tokens in the last block
-        # Note: For chunked prefill, each block is exactly block_size tokens
-        # The cpu_block_table only contains full prefill blocks
+        # Calculate valid tokens in the last CPU block
+        # CRITICAL: Use original prefill length, not current seq length!
+        # CPU blocks are fixed after prefill, their content doesn't change during decode.
         block_size = kvcache_manager.block_size
         num_prefill_blocks = len(cpu_block_table)
-        # All prefill blocks are full (block_size tokens each)
-        last_block_valid_tokens = block_size
+        total_prefill_tokens = kvcache_manager.get_prefill_len(seq)  # Original prefill length
+        last_block_valid_tokens = total_prefill_tokens % block_size
+        if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
+            last_block_valid_tokens = block_size  # Last block was exactly full
 
-        # Apply sparse policy if enabled
-        if kvcache_manager.sparse_policy is not None:
+        # Apply sparse policy if enabled (Quest does Top-K selection for decode)
+        sparse_policy = kvcache_manager.sparse_policy
+        if sparse_policy is not None:
             policy_ctx = PolicyContext(
                 query_chunk_idx=0,
                 num_query_chunks=1,
@@ -445,18 +529,25 @@ class Attention(nn.Module):
                 block_size=kvcache_manager.block_size,
                 total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
             )
-            cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
+            cpu_block_table = sparse_policy.select_blocks(
                 cpu_block_table, policy_ctx
            )
 
         offload_engine = kvcache_manager.offload_engine
-        load_slots = offload_engine.decode_load_slots  # Available slots for loading
 
-        # Use ring buffer pipeline (same as prefill) to load prefilled blocks
-        o_acc, lse_acc = self._decode_ring_buffer_pipeline(
-            q_batched, cpu_block_table, load_slots, offload_engine,
-            block_size, last_block_valid_tokens
-        )
+        # Use cross-layer pipeline if active (initialized in model_runner)
+        if offload_engine.is_pipeline_active():
+            o_acc, lse_acc = self._decode_with_layer_pipeline(
+                q_batched, cpu_block_table, offload_engine,
+                block_size, last_block_valid_tokens
+            )
+        else:
+            # Fallback to original ring buffer pipeline
+            load_slots = offload_engine.decode_load_slots
+            o_acc, lse_acc = self._decode_ring_buffer_pipeline(
+                q_batched, cpu_block_table, load_slots, offload_engine,
+                block_size, last_block_valid_tokens
+            )
 
         # Now attend to accumulated decode tokens from per-layer decode buffer
         pos_in_block = context.decode_pos_in_block
@@ -569,3 +660,62 @@ class Attention(nn.Module):
                 o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
 
         return o_acc, lse_acc
+
+    def _decode_with_layer_pipeline(
+        self,
+        q_batched: torch.Tensor,
+        cpu_block_table: list,
+        offload_engine,
+        block_size: int,
+        last_block_valid_tokens: int,
+    ):
+        """
+        Decode using cross-layer pipeline for optimized H2D transfer.
+
+        This method uses pre-loaded layer buffers instead of loading
+        blocks one by one. The pipeline loads the next layer's data
+        while the current layer computes, achieving transfer/compute overlap.
+
+        The key insight is that each layer needs the SAME blocks but from
+        different layers of CPU cache. By double-buffering and pipelining
+        across layers, we reduce total latency.
+        """
+        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
+
+        num_blocks = len(cpu_block_table)
+        if num_blocks == 0:
+            return None, None
+
+        compute_stream = offload_engine.compute_stream
+
+        # Get KV from pre-loaded layer buffer (triggers next layer loading)
+        prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks)
+
+        # prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
+        # Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
+        total_tokens = num_blocks * block_size
+
+        # Handle partial last block
+        if last_block_valid_tokens < block_size:
+            # Only use valid tokens from last block
+            actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
+            # Flatten and truncate
+            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
+            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
+        else:
+            prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
+            prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
+
+        # Add batch dimension: [1, total_tokens, kv_heads, head_dim]
+        prev_k_batched = prev_k_flat.unsqueeze(0)
+        prev_v_batched = prev_v_flat.unsqueeze(0)
+
+        # Compute attention on all prefilled blocks at once
+        with torch.cuda.stream(compute_stream):
+            o_acc, lse_acc = flash_attn_with_lse(
+                q_batched, prev_k_batched, prev_v_batched,
+                softmax_scale=self.scale,
+                causal=False,
+            )
+
+        return o_acc, lse_acc
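The cross-layer pipeline referenced above (`get_decode_layer_kv` "triggers next layer loading") amounts to double-buffered H2D prefetch on a copy stream. The helper below is a hedged sketch of that idea with made-up names, not the engine's actual API:

```python
def make_prefetcher(cpu_layers, buf_shape, dtype=torch.float16):
    """Double-buffered H2D prefetch; cpu_layers[i] should be pinned CPU KV for layer i."""
    copy_stream = torch.cuda.Stream()
    bufs = [torch.empty(buf_shape, dtype=dtype, device="cuda") for _ in range(2)]
    events = [torch.cuda.Event() for _ in range(2)]

    def prefetch(layer_id):
        slot = layer_id % 2
        with torch.cuda.stream(copy_stream):
            bufs[slot].copy_(cpu_layers[layer_id], non_blocking=True)  # async H2D copy
            events[slot].record()                                      # marks "slot ready"
        return slot

    def wait(slot, compute_stream):
        events[slot].wait(compute_stream)   # compute stream blocks only on its own data
        return bufs[slot]

    return prefetch, wait

# Layer i computes on its buffer while prefetch(i + 1) fills the other one.
```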
757 tests/modeling_qwen3.py Normal file
@@ -0,0 +1,757 @@
"""
Custom Qwen3 implementation using only torch and transformers.
This file provides a clean reference implementation for understanding the model computation graph.

Computation Graph:
==================

Input: token_ids [batch, seq_len]
  -> embed_tokens (Embedding, [vocab_size, hidden_size])
  -> hidden_states [batch, seq_len, hidden_size]
  -> Decoder Layer (x N):
       Self Attention Block:
         input_layernorm (RMSNorm)
         Qwen3Attention:
           Q = q_proj(x) -> q_norm -> reshape
           K = k_proj(x) -> k_norm -> reshape
           V = v_proj(x) -> reshape
           Q, K = apply_rotary_pos_emb(Q, K, cos, sin)
           attn_output = attention(Q, K, V)
           output = o_proj(attn_output)
         hidden_states = residual + attn_output
       MLP Block:
         post_attention_layernorm (RMSNorm)
         Qwen3MLP:
           gate = gate_proj(x)
           up = up_proj(x)
           output = down_proj(silu(gate) * up)
         hidden_states = residual + mlp_output
  -> norm (final RMSNorm)
  -> lm_head ([hidden_size, vocab_size])
  -> logits [batch, seq_len, vocab_size]
"""

import math
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F


class Qwen3RMSNorm(nn.Module):
    """RMSNorm implementation."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(input_dtype)


class Qwen3RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE)."""

    def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Compute inverse frequencies
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @torch.no_grad()
    def forward(self, x: torch.Tensor, position_ids: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch, seq_len, num_heads, head_dim] or similar
            position_ids: Position indices [batch, seq_len]

        Returns:
            cos, sin: [batch, seq_len, head_dim]
        """
        # inv_freq: [dim/2]
        # position_ids: [batch, seq_len]
        inv_freq_expanded = self.inv_freq[None, :, None].float()  # [1, dim/2, 1]
        position_ids_expanded = position_ids[:, None, :].float()  # [batch, 1, seq_len]

        # freqs: [batch, dim/2, seq_len]
        freqs = inv_freq_expanded @ position_ids_expanded
        # freqs: [batch, seq_len, dim/2]
        freqs = freqs.transpose(1, 2)

        # Duplicate for full head_dim: [batch, seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)

        cos = emb.cos().to(x.dtype)
        sin = emb.sin().to(x.dtype)

        return cos, sin


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embeddings to Q and K.

    Args:
        q: [batch, num_heads, seq_len, head_dim]
        k: [batch, num_kv_heads, seq_len, head_dim]
        cos: [batch, seq_len, head_dim]
        sin: [batch, seq_len, head_dim]

    Returns:
        q_embed, k_embed with same shapes as inputs
    """
    # Unsqueeze for broadcasting: [batch, 1, seq_len, head_dim]
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed

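A shape-only check of the rotary embedding helpers above (toy sizes):

```python
rope = Qwen3RotaryEmbedding(dim=8)
pos = torch.arange(4).unsqueeze(0)               # [1, 4]
x = torch.zeros(1, 2, 4, 8)                      # only dtype is used by rope.forward
cos, sin = rope(x, pos)                          # both [1, 4, 8]

q = torch.randn(1, 2, 4, 8)                      # [batch, heads, seq, head_dim]
k = torch.randn(1, 2, 4, 8)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert q_rot.shape == q.shape and k_rot.shape == k.shape
```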
class Qwen3Attention(nn.Module):
    """
    Qwen3 Multi-Head Attention with Grouped Query Attention (GQA) support.

    Data Flow:
    ---------
    hidden_states [batch, seq_len, hidden_size]
      -> q_proj -> q_norm -> reshape -> Q [batch, num_heads, seq_len, head_dim]
      -> k_proj -> k_norm -> reshape -> K [batch, num_kv_heads, seq_len, head_dim]
      -> v_proj -> reshape           -> V [batch, num_kv_heads, seq_len, head_dim]
      -> apply_rotary_pos_emb(Q, K)
      -> attention(Q, K, V) -> attn_output [batch, num_heads, seq_len, head_dim]
      -> reshape -> o_proj -> output [batch, seq_len, hidden_size]
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        max_position_embeddings: int = 32768,
        rope_theta: float = 10000.0,
        attention_bias: bool = False,
        rms_norm_eps: float = 1e-6,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_attention_heads
        self.num_kv_heads = num_key_value_heads
        self.head_dim = head_dim
        self.num_kv_groups = num_attention_heads // num_key_value_heads
        self.layer_idx = layer_idx

        # Scaling factor
        self.scaling = head_dim ** -0.5

        # QKV projections
        self.q_proj = nn.Linear(hidden_size, num_attention_heads * head_dim, bias=attention_bias)
        self.k_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
        self.v_proj = nn.Linear(hidden_size, num_key_value_heads * head_dim, bias=attention_bias)
        self.o_proj = nn.Linear(num_attention_heads * head_dim, hidden_size, bias=attention_bias)

        # QK normalization (Qwen3 specific)
        self.q_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(head_dim, eps=rms_norm_eps)

        # Rotary embeddings
        self.rotary_emb = Qwen3RotaryEmbedding(
            head_dim,
            max_position_embeddings=max_position_embeddings,
            base=rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_qkv: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            position_ids: [batch, seq_len]
            attention_mask: [batch, 1, seq_len, kv_seq_len] (causal mask)
            past_key_value: (k_cache, v_cache) from previous steps
            use_cache: Whether to return updated cache
            output_qkv: Whether to output Q, K, V tensors for debugging

        Returns:
            output: [batch, seq_len, hidden_size]
            past_key_value: Updated cache (if use_cache=True)
            qkv_dict: {"q": Q, "k": K, "v": V} (if output_qkv=True)
        """
        batch_size, seq_len, _ = hidden_states.shape

        # === QKV Projections ===
        q = self.q_proj(hidden_states)  # [batch, seq_len, num_heads * head_dim]
        k = self.k_proj(hidden_states)  # [batch, seq_len, num_kv_heads * head_dim]
        v = self.v_proj(hidden_states)  # [batch, seq_len, num_kv_heads * head_dim]

        # Reshape to [batch, seq_len, num_heads, head_dim]
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim)

        # === QK Normalization (Qwen3 specific) ===
        q = self.q_norm(q)
        k = self.k_norm(k)

        # Transpose to [batch, num_heads, seq_len, head_dim]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # === Rotary Position Embeddings ===
        cos, sin = self.rotary_emb(v, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # === KV Cache Update ===
        if past_key_value is not None:
            k_cache, v_cache = past_key_value
            k = torch.cat([k_cache, k], dim=2)
            v = torch.cat([v_cache, v], dim=2)

        new_past_key_value = (k, v) if use_cache else None

        # === Grouped Query Attention (expand KV heads if needed) ===
        if self.num_kv_groups > 1:
            # Repeat KV for each query group
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        # === Attention Computation (using SDPA for memory efficiency) ===
        # Use PyTorch's scaled_dot_product_attention which can use FlashAttention backend
        # is_causal only works when q_len == kv_len (prefill), not during decode
        q_len, kv_len = q.shape[2], k.shape[2]
        is_causal = (q_len == kv_len) and (q_len > 1)

        attn_output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=is_causal,
            scale=self.scaling,
        )  # [batch, num_heads, seq_len, head_dim]

        # === Output Projection ===
        # Transpose back and reshape
        attn_output = attn_output.transpose(1, 2).contiguous()  # [batch, seq_len, num_heads, head_dim]
        attn_output = attn_output.view(batch_size, seq_len, -1)  # [batch, seq_len, hidden_size]
        output = self.o_proj(attn_output)

        # Optional QKV output for debugging
        qkv_dict = None
        if output_qkv:
            qkv_dict = {
                "q": q,  # [batch, num_heads, seq_len, head_dim] (post-RoPE)
                "k": k,  # [batch, num_heads, kv_seq_len, head_dim] (post-RoPE, expanded)
                "v": v,  # [batch, num_heads, kv_seq_len, head_dim] (expanded)
            }

        return output, new_past_key_value, qkv_dict

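The GQA expansion in `forward` simply repeats each KV head for its query group; a one-line check of what `repeat_interleave` produces:

```python
# 8 query heads, 2 KV heads -> each KV head is shared by 4 query heads.
k = torch.randn(1, 2, 5, 64)                            # [batch, num_kv_heads, seq, head_dim]
k_expanded = k.repeat_interleave(4, dim=1)              # [1, 8, 5, 64]
assert torch.equal(k_expanded[:, 0], k_expanded[:, 3])  # heads 0-3 all see KV head 0
```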
class Qwen3MLP(nn.Module):
    """
    Qwen3 MLP with SwiGLU activation.

    Data Flow:
    ---------
    hidden_states [batch, seq_len, hidden_size]
      -> gate_proj -> gate [batch, seq_len, intermediate_size]
      -> up_proj   -> up   [batch, seq_len, intermediate_size]
      -> silu(gate) * up
      -> down_proj -> output [batch, seq_len, hidden_size]
    """

    def __init__(self, hidden_size: int, intermediate_size: int, bias: bool = False):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.gate_proj(x)
        up = self.up_proj(x)
        return self.down_proj(F.silu(gate) * up)


class Qwen3DecoderLayer(nn.Module):
    """Single Qwen3 Decoder Layer."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        max_position_embeddings: int = 32768,
        rope_theta: float = 10000.0,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        mlp_bias: bool = False,
        layer_idx: int = 0,
    ):
        super().__init__()
        self.layer_idx = layer_idx

        # Pre-attention LayerNorm
        self.input_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        # Self-attention
        self.self_attn = Qwen3Attention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            head_dim=head_dim,
            max_position_embeddings=max_position_embeddings,
            rope_theta=rope_theta,
            attention_bias=attention_bias,
            rms_norm_eps=rms_norm_eps,
            layer_idx=layer_idx,
        )

        # Post-attention LayerNorm
        self.post_attention_layernorm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

        # MLP
        self.mlp = Qwen3MLP(hidden_size, intermediate_size, bias=mlp_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        output_qkv: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[dict]]:
        """
        Args:
            hidden_states: [batch, seq_len, hidden_size]
            position_ids: [batch, seq_len]
            attention_mask: Causal attention mask
            past_key_value: KV cache for this layer
            use_cache: Whether to return updated cache
            output_qkv: Whether to output Q, K, V for debugging

        Returns:
            hidden_states: [batch, seq_len, hidden_size]
            past_key_value: Updated cache
            qkv_dict: QKV tensors (if output_qkv=True)
        """
        # === Self Attention Block ===
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attn_output, new_past_key_value, qkv_dict = self.self_attn(
            hidden_states=hidden_states,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            output_qkv=output_qkv,
        )

        hidden_states = residual + attn_output

        # === MLP Block ===
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, new_past_key_value, qkv_dict

class Qwen3Model(nn.Module):
    """Qwen3 Transformer Model (without LM head)."""

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        intermediate_size: int,
        num_hidden_layers: int,
        num_attention_heads: int,
        num_key_value_heads: int,
        head_dim: int,
        max_position_embeddings: int = 32768,
        rope_theta: float = 10000.0,
        rms_norm_eps: float = 1e-6,
        attention_bias: bool = False,
        mlp_bias: bool = False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_hidden_layers = num_hidden_layers

        # Token embeddings
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)

        # Decoder layers
        self.layers = nn.ModuleList([
            Qwen3DecoderLayer(
                hidden_size=hidden_size,
                intermediate_size=intermediate_size,
                num_attention_heads=num_attention_heads,
                num_key_value_heads=num_key_value_heads,
                head_dim=head_dim,
                max_position_embeddings=max_position_embeddings,
                rope_theta=rope_theta,
                rms_norm_eps=rms_norm_eps,
                attention_bias=attention_bias,
                mlp_bias=mlp_bias,
                layer_idx=i,
            )
            for i in range(num_hidden_layers)
        ])

        # Final LayerNorm
        self.norm = Qwen3RMSNorm(hidden_size, eps=rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        output_qkv_layers: Optional[List[int]] = None,
    ) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
        """
        Args:
            input_ids: [batch, seq_len]
            position_ids: [batch, seq_len]
            attention_mask: [batch, seq_len] or pre-computed 4D mask
            past_key_values: List of (k, v) tuples for each layer
            use_cache: Whether to return new cache
            output_qkv_layers: List of layer indices to output QKV for

        Returns:
            hidden_states: [batch, seq_len, hidden_size]
            new_past_key_values: Updated cache
            qkv_outputs: {layer_idx: qkv_dict}
        """
        batch_size, seq_len = input_ids.shape

        # Embedding
        hidden_states = self.embed_tokens(input_ids)

        # Position IDs
        if position_ids is None:
            past_len = past_key_values[0][0].shape[2] if past_key_values else 0
            position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand(batch_size, -1)

        # Attention mask (create causal mask if not provided)
        if attention_mask is None or attention_mask.dim() == 2:
            kv_seq_len = seq_len + (past_key_values[0][0].shape[2] if past_key_values else 0)
            causal_mask = torch.triu(
                torch.full((seq_len, kv_seq_len), float("-inf"), device=input_ids.device),
                diagonal=kv_seq_len - seq_len + 1,
            )
            attention_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, kv_seq_len]

        # Initialize cache list
        new_past_key_values = [] if use_cache else None
        qkv_outputs = {} if output_qkv_layers else None

        # Decoder layers
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values else None
            output_qkv = output_qkv_layers is not None and i in output_qkv_layers

            hidden_states, new_kv, qkv_dict = layer(
                hidden_states=hidden_states,
                position_ids=position_ids,
                attention_mask=attention_mask,
                past_key_value=past_kv,
                use_cache=use_cache,
                output_qkv=output_qkv,
            )

            if use_cache:
                new_past_key_values.append(new_kv)
            if qkv_dict is not None:
                qkv_outputs[i] = qkv_dict

        # Final norm
        hidden_states = self.norm(hidden_states)

        return hidden_states, new_past_key_values, qkv_outputs

class Qwen3ForCausalLM(nn.Module):
|
||||||
|
"""Qwen3 Model with Language Modeling head."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size: int,
|
||||||
|
hidden_size: int,
|
||||||
|
intermediate_size: int,
|
||||||
|
num_hidden_layers: int,
|
||||||
|
num_attention_heads: int,
|
||||||
|
num_key_value_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
max_position_embeddings: int = 32768,
|
||||||
|
rope_theta: float = 10000.0,
|
||||||
|
rms_norm_eps: float = 1e-6,
|
||||||
|
attention_bias: bool = False,
|
||||||
|
mlp_bias: bool = False,
|
||||||
|
tie_word_embeddings: bool = True,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.tie_word_embeddings = tie_word_embeddings
|
||||||
|
|
||||||
|
# Transformer model
|
||||||
|
self.model = Qwen3Model(
|
||||||
|
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            head_dim=head_dim,
            max_position_embeddings=max_position_embeddings,
            rope_theta=rope_theta,
            rms_norm_eps=rms_norm_eps,
            attention_bias=attention_bias,
            mlp_bias=mlp_bias,
        )

        # LM head
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        output_qkv_layers: Optional[List[int]] = None,
    ) -> Tuple[torch.Tensor, Optional[List], Optional[dict]]:
        """
        Args:
            input_ids: [batch, seq_len]
            ... (same as Qwen3Model)

        Returns:
            logits: [batch, seq_len, vocab_size]
            past_key_values: Updated KV cache
            qkv_outputs: QKV tensors for specified layers
        """
        hidden_states, new_past_key_values, qkv_outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_qkv_layers=output_qkv_layers,
        )

        logits = self.lm_head(hidden_states)

        return logits, new_past_key_values, qkv_outputs

    @classmethod
    def from_pretrained(cls, model_path: str, dtype: torch.dtype = torch.float16) -> "Qwen3ForCausalLM":
        """
        Load weights from a pretrained Qwen3 model.

        Args:
            model_path: Path to model directory containing config.json and model weights
            dtype: Data type for model weights

        Returns:
            Initialized Qwen3ForCausalLM model
        """
        import json
        import os
        from safetensors.torch import load_file

        # Load config
        config_path = os.path.join(model_path, "config.json")
        with open(config_path) as f:
            config = json.load(f)

        # Create model
        model = cls(
            vocab_size=config["vocab_size"],
            hidden_size=config["hidden_size"],
            intermediate_size=config["intermediate_size"],
            num_hidden_layers=config["num_hidden_layers"],
            num_attention_heads=config["num_attention_heads"],
            num_key_value_heads=config.get("num_key_value_heads", config["num_attention_heads"]),
            head_dim=config.get("head_dim", config["hidden_size"] // config["num_attention_heads"]),
            max_position_embeddings=config.get("max_position_embeddings", 32768),
            rope_theta=config.get("rope_theta", 10000.0),
            rms_norm_eps=config.get("rms_norm_eps", 1e-6),
            attention_bias=config.get("attention_bias", False),
            mlp_bias=config.get("mlp_bias", False),
            tie_word_embeddings=config.get("tie_word_embeddings", True),
        )

        # Load weights
        weight_files = sorted([
            f for f in os.listdir(model_path)
            if f.endswith(".safetensors")
        ])

        state_dict = {}
        for wf in weight_files:
            state_dict.update(load_file(os.path.join(model_path, wf)))

        # Load into model
        model.load_state_dict(state_dict, strict=False)

        # Tie lm_head weights to embed_tokens if configured
        if model.tie_word_embeddings:
            model.lm_head.weight = model.model.embed_tokens.weight

        model = model.to(dtype)

        return model

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 32,
        temperature: float = 1.0,
        do_sample: bool = True,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> torch.Tensor:
        """Simple autoregressive generation."""
        device = input_ids.device
        batch_size, seq_len = input_ids.shape
        past_key_values = None
        generated = input_ids.clone()

        for _ in range(max_new_tokens):
            if past_key_values is None:
                current_input = generated
            else:
                current_input = generated[:, -1:]

            logits, past_key_values, _ = self(
                input_ids=current_input,
                past_key_values=past_key_values,
                use_cache=True,
            )

            next_token_logits = logits[:, -1, :]
            if temperature > 0 and do_sample:
                next_token_logits = next_token_logits / temperature
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = next_token_logits.argmax(dim=-1, keepdim=True)

            generated = torch.cat([generated, next_token], dim=1)

            if eos_token_id is not None and (next_token == eos_token_id).all():
                break

        return generated


def print_computation_graph():
    """Print the computation graph for reference."""
    print(__doc__)


if __name__ == "__main__":
    print_computation_graph()
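For orientation, a minimal usage sketch of the class above. This is a sketch only: the import path of the class, the model directory, and the tokenizer handling are assumptions, not part of this file.

```python
import torch
from transformers import AutoTokenizer

# Assumes Qwen3ForCausalLM (defined above) is importable from its module,
# and that a Qwen3 checkpoint directory with config.json + *.safetensors exists.
MODEL_DIR = "/path/to/Qwen3-4B-Instruct-2507"  # hypothetical path

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = Qwen3ForCausalLM.from_pretrained(MODEL_DIR, dtype=torch.float16).cuda()

input_ids = tokenizer("Hello, world!", return_tensors="pt").input_ids.cuda()
output_ids = model.generate(
    input_ids,
    max_new_tokens=16,
    temperature=0.7,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```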
@@ -1,23 +0,0 @@
cmake_minimum_required(VERSION 3.18)
project(sgdma_test CUDA CXX)

# Find CUDA
enable_language(CUDA)
find_package(CUDA REQUIRED)

# Set C++ standard
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)

# CUDA flags
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 --use_fast_math")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")

# Build test executable
add_executable(sgdma_test sgdma_test.cpp)
target_link_libraries(sgdma_test cudart)

# Set output directory
set_target_properties(sgdma_test PROPERTIES
    RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
)
@@ -1,326 +0,0 @@
#include <cuda_runtime.h>
#include <iostream>
#include <chrono>
#include <cstring>
#include <cstdlib>
#include <iomanip>

// CUDA error checking macro
#define CUDA_CHECK(call) do { \
    cudaError_t err = call; \
    if (err != cudaSuccess) { \
        std::cerr << "CUDA Error in " << __FILE__ << " at line " << __LINE__ << ": " \
                  << cudaGetErrorString(err) << std::endl; \
        exit(EXIT_FAILURE); \
    } \
} while (0)

// Configuration matching nano-vllm realistic parameters
struct Config {
    int num_layers = 32;
    int num_blocks = 10;  // Reduced from 100 to avoid huge allocation
    int block_size = 4096;
    int num_kv_heads = 8;
    int head_dim = 128;
    int dtype_size = 2;  // float16

    // Derived parameters (use size_t to avoid overflow)
    size_t features_per_block() const { return (size_t)block_size * num_kv_heads * head_dim; }
    size_t bytes_per_block() const { return features_per_block() * dtype_size; }
    int total_blocks_per_layer() const { return num_blocks; }
    size_t bytes_per_layer() const { return (size_t)num_blocks * bytes_per_block(); }
    size_t total_bytes() const { return (size_t)num_layers * bytes_per_layer(); }
};

// Timer utility
class Timer {
    std::chrono::high_resolution_clock::time_point start_time;
public:
    void start() { start_time = std::chrono::high_resolution_clock::now(); }
    double elapsed_ms() {
        auto end = std::chrono::high_resolution_clock::now();
        return std::chrono::duration<double, std::milli>(end - start_time).count();
    }
};

// Initialize CPU memory with test pattern
void init_test_data(void* data, size_t bytes, int seed) {
    uint16_t* ptr = static_cast<uint16_t*>(data);
    size_t num_elements = bytes / sizeof(uint16_t);
    for (size_t i = 0; i < num_elements; i++) {
        ptr[i] = static_cast<uint16_t>((seed + i) % 65536);
    }
}

// Verify data correctness
bool verify_data(const void* data1, const void* data2, size_t bytes) {
    const uint16_t* p1 = static_cast<const uint16_t*>(data1);
    const uint16_t* p2 = static_cast<const uint16_t*>(data2);
    size_t num_elements = bytes / sizeof(uint16_t);

    for (size_t i = 0; i < num_elements; i++) {
        if (p1[i] != p2[i]) {
            std::cerr << "Mismatch at element " << i << ": "
                      << p1[i] << " != " << p2[i] << std::endl;
            return false;
        }
    }
    return true;
}

// ============================================================
// Test 1: Basic Functionality Test
// ============================================================
bool test_basic_functionality(const Config& cfg) {
    std::cout << "\n[Test 1] Basic Functionality Test" << std::endl;
    std::cout << " Testing cudaMemcpy2D correctness with strided layout" << std::endl;

    // Allocate strided CPU memory (pinned)
    // Layout: [num_layers, num_blocks, block_features]
    size_t total_bytes = cfg.total_bytes();
    std::cout << " Allocating " << total_bytes / 1024.0 / 1024.0 / 1024.0 << " GB pinned memory..." << std::endl;
    void* cpu_strided = nullptr;
    CUDA_CHECK(cudaMallocHost(&cpu_strided, total_bytes));
    std::cout << " CPU strided memory allocated at: " << cpu_strided << std::endl;

    // Allocate GPU memory for one block (all layers)
    size_t gpu_block_bytes = cfg.num_layers * cfg.bytes_per_block();
    void* gpu_data = nullptr;
    CUDA_CHECK(cudaMalloc(&gpu_data, gpu_block_bytes));

    // Allocate CPU verify buffer
    void* cpu_verify = nullptr;
    CUDA_CHECK(cudaMallocHost(&cpu_verify, gpu_block_bytes));

    // Initialize strided CPU memory
    init_test_data(cpu_strided, total_bytes, 12345);

    // Test: Copy block_id=5 from CPU to GPU using cudaMemcpy2D
    int test_block_id = 5;
    size_t spitch = cfg.bytes_per_layer();  // Source pitch (stride between layers)
    size_t dpitch = cfg.bytes_per_block();  // Destination pitch (contiguous)
    size_t width = cfg.bytes_per_block();   // Width to copy per row
    size_t height = cfg.num_layers;         // Number of rows (layers)

    // Debug: print parameters
    std::cout << " cudaMemcpy2D parameters:" << std::endl;
    std::cout << " spitch: " << spitch << " bytes" << std::endl;
    std::cout << " dpitch: " << dpitch << " bytes" << std::endl;
    std::cout << " width: " << width << " bytes" << std::endl;
    std::cout << " height: " << height << " rows" << std::endl;
    std::cout << " dpitch >= width: " << (dpitch >= width ? "yes" : "no") << std::endl;
    std::cout << " spitch >= width: " << (spitch >= width ? "yes" : "no") << std::endl;

    // Calculate source pointer (first layer, block_id)
    uint8_t* src_ptr = static_cast<uint8_t*>(cpu_strided) + test_block_id * cfg.bytes_per_block();

    // H2D transfer
    CUDA_CHECK(cudaMemcpy2D(
        gpu_data,   // dst
        dpitch,     // dpitch
        src_ptr,    // src
        spitch,     // spitch
        width,      // width
        height,     // height
        cudaMemcpyHostToDevice
    ));

    // D2H transfer back
    CUDA_CHECK(cudaMemcpy2D(
        cpu_verify, // dst
        dpitch,     // dpitch
        gpu_data,   // src
        dpitch,     // spitch
        width,      // width
        height,     // height
        cudaMemcpyDeviceToHost
    ));

    // Verify correctness
    bool passed = true;
    for (int layer = 0; layer < cfg.num_layers; layer++) {
        uint8_t* expected_ptr = static_cast<uint8_t*>(cpu_strided) +
                                layer * cfg.bytes_per_layer() +
                                test_block_id * cfg.bytes_per_block();
        uint8_t* actual_ptr = static_cast<uint8_t*>(cpu_verify) +
                              layer * cfg.bytes_per_block();

        if (!verify_data(expected_ptr, actual_ptr, cfg.bytes_per_block())) {
            std::cerr << " Verification failed at layer " << layer << std::endl;
            passed = false;
            break;
        }
    }

    // Cleanup
    CUDA_CHECK(cudaFreeHost(cpu_strided));
    CUDA_CHECK(cudaFreeHost(cpu_verify));
    CUDA_CHECK(cudaFree(gpu_data));

    std::cout << " Result: " << (passed ? "PASSED ✓" : "FAILED ✗") << std::endl;
    return passed;
}

// ============================================================
// Test 2: Performance Benchmark
// ============================================================
void test_performance_benchmark(const Config& cfg) {
    std::cout << "\n[Test 2] Performance Benchmark" << std::endl;
    std::cout << " Configuration:" << std::endl;
    std::cout << " num_layers: " << cfg.num_layers << std::endl;
    std::cout << " num_blocks: " << cfg.num_blocks << std::endl;
    std::cout << " block_size: " << cfg.block_size << std::endl;
    std::cout << " num_kv_heads: " << cfg.num_kv_heads << std::endl;
    std::cout << " head_dim: " << cfg.head_dim << std::endl;
    std::cout << " dtype_size: " << cfg.dtype_size << " bytes" << std::endl;
    std::cout << " bytes_per_block: " << cfg.bytes_per_block() / 1024.0 << " KB" << std::endl;
    std::cout << " total transfer size: " << cfg.num_layers * cfg.bytes_per_block() / 1024.0 / 1024.0 << " MB" << std::endl;

    const int num_iterations = 100;
    const int warmup = 10;
    int test_block_id = 5;

    // Allocate memory
    size_t total_bytes = cfg.total_bytes();
    void* cpu_strided = nullptr;
    CUDA_CHECK(cudaMallocHost(&cpu_strided, total_bytes));

    void* cpu_contiguous = nullptr;
    size_t gpu_block_bytes = cfg.num_layers * cfg.bytes_per_block();
    CUDA_CHECK(cudaMallocHost(&cpu_contiguous, gpu_block_bytes));

    void* gpu_data = nullptr;
    CUDA_CHECK(cudaMalloc(&gpu_data, gpu_block_bytes));

    init_test_data(cpu_strided, total_bytes, 12345);
    init_test_data(cpu_contiguous, gpu_block_bytes, 12345);

    Timer timer;
    double elapsed;
    double bandwidth;

    // ========================================
    // Method A: cudaMemcpy2D with strided layout
    // ========================================
    size_t spitch = cfg.bytes_per_layer();
    size_t dpitch = cfg.bytes_per_block();
    size_t width = cfg.bytes_per_block();
    size_t height = cfg.num_layers;
    uint8_t* src_ptr = static_cast<uint8_t*>(cpu_strided) + test_block_id * cfg.bytes_per_block();

    // Warmup
    for (int i = 0; i < warmup; i++) {
        CUDA_CHECK(cudaMemcpy2D(gpu_data, dpitch, src_ptr, spitch, width, height, cudaMemcpyHostToDevice));
    }
    CUDA_CHECK(cudaDeviceSynchronize());

    // Benchmark
    timer.start();
    for (int i = 0; i < num_iterations; i++) {
        CUDA_CHECK(cudaMemcpy2D(gpu_data, dpitch, src_ptr, spitch, width, height, cudaMemcpyHostToDevice));
    }
    CUDA_CHECK(cudaDeviceSynchronize());
    elapsed = timer.elapsed_ms();
    bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0);

    std::cout << "\n Method A (cudaMemcpy2D strided):" << std::endl;
    std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl;
    std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl;
    double method_a_bw = bandwidth;

    // ========================================
    // Method B: cudaMemcpy with contiguous layout (baseline)
    // ========================================
    // Warmup
    for (int i = 0; i < warmup; i++) {
        CUDA_CHECK(cudaMemcpy(gpu_data, cpu_contiguous, gpu_block_bytes, cudaMemcpyHostToDevice));
    }
    CUDA_CHECK(cudaDeviceSynchronize());

    // Benchmark
    timer.start();
    for (int i = 0; i < num_iterations; i++) {
        CUDA_CHECK(cudaMemcpy(gpu_data, cpu_contiguous, gpu_block_bytes, cudaMemcpyHostToDevice));
    }
    CUDA_CHECK(cudaDeviceSynchronize());
    elapsed = timer.elapsed_ms();
    bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0);

    std::cout << "\n Method B (cudaMemcpy contiguous):" << std::endl;
    std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl;
    std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl;
    double method_b_bw = bandwidth;

    // ========================================
    // Method C: Layer-by-layer copy (simulate PyTorch non-contiguous)
    // ========================================
    // Warmup
    for (int i = 0; i < warmup; i++) {
        for (int layer = 0; layer < cfg.num_layers; layer++) {
            uint8_t* src_layer = static_cast<uint8_t*>(cpu_strided) +
                                 layer * cfg.bytes_per_layer() +
                                 test_block_id * cfg.bytes_per_block();
            uint8_t* dst_layer = static_cast<uint8_t*>(gpu_data) + layer * cfg.bytes_per_block();
            CUDA_CHECK(cudaMemcpy(dst_layer, src_layer, cfg.bytes_per_block(), cudaMemcpyHostToDevice));
        }
    }
    CUDA_CHECK(cudaDeviceSynchronize());

    // Benchmark
    timer.start();
    for (int i = 0; i < num_iterations; i++) {
        for (int layer = 0; layer < cfg.num_layers; layer++) {
            uint8_t* src_layer = static_cast<uint8_t*>(cpu_strided) +
                                 layer * cfg.bytes_per_layer() +
                                 test_block_id * cfg.bytes_per_block();
            uint8_t* dst_layer = static_cast<uint8_t*>(gpu_data) + layer * cfg.bytes_per_block();
            CUDA_CHECK(cudaMemcpy(dst_layer, src_layer, cfg.bytes_per_block(), cudaMemcpyHostToDevice));
        }
    }
    CUDA_CHECK(cudaDeviceSynchronize());
    elapsed = timer.elapsed_ms();
    bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0);

    std::cout << "\n Method C (layer-by-layer copy):" << std::endl;
    std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl;
    std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl;
    double method_c_bw = bandwidth;

    // Summary
    std::cout << "\n ========================================" << std::endl;
    std::cout << " Performance Summary:" << std::endl;
    std::cout << " Method A vs Method B: " << std::setprecision(2) << (method_a_bw / method_b_bw * 100) << "%" << std::endl;
    std::cout << " Method A vs Method C: " << std::setprecision(2) << (method_a_bw / method_c_bw) << "x speedup" << std::endl;
    std::cout << " ========================================" << std::endl;

    // Cleanup
    CUDA_CHECK(cudaFreeHost(cpu_strided));
    CUDA_CHECK(cudaFreeHost(cpu_contiguous));
    CUDA_CHECK(cudaFree(gpu_data));
}

int main() {
    std::cout << "=== cudaMemcpy2D Test ===" << std::endl;

    // Print CUDA device info
    int device;
    CUDA_CHECK(cudaGetDevice(&device));
    cudaDeviceProp prop;
    CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
    std::cout << "Using GPU: " << prop.name << std::endl;
    std::cout << "Memory Clock Rate: " << prop.memoryClockRate / 1000 << " MHz" << std::endl;
    std::cout << "Memory Bus Width: " << prop.memoryBusWidth << " bits" << std::endl;
    std::cout << "Peak Memory Bandwidth: " <<
        2.0 * prop.memoryClockRate * (prop.memoryBusWidth / 8) / 1.0e6 << " GB/s" << std::endl;

    Config cfg;

    // Run tests
    bool test1_passed = test_basic_functionality(cfg);
    test_performance_benchmark(cfg);

    std::cout << "\n=== Test Complete ===" << std::endl;
    std::cout << "All tests " << (test1_passed ? "PASSED ✓" : "FAILED ✗") << std::endl;

    return test1_passed ? 0 : 1;
}
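The benchmark above compares a single strided 2D copy against a per-layer copy loop. For reference, the same access pattern expressed at the PyTorch level looks like the sketch below; whether PyTorch lowers the non-contiguous host slice to one `cudaMemcpy2D` or to a per-layer loop is exactly the difference the C++ test measures. Sizes here are small placeholders, not the benchmark's configuration.

```python
import torch

num_layers, num_blocks, block_size, kv_heads, head_dim = 4, 10, 256, 8, 64

# CPU cache laid out as [num_layers, num_blocks, block_size, kv_heads, head_dim],
# pinned so asynchronous H2D copies are possible.
cpu_cache = torch.empty(num_layers, num_blocks, block_size, kv_heads, head_dim,
                        dtype=torch.float16, pin_memory=True)
gpu_slot = torch.empty(num_layers, block_size, kv_heads, head_dim,
                       dtype=torch.float16, device="cuda")

block_id = 5
# Gathering one block across all layers is a strided source (row stride = one
# layer of the CPU cache) with a contiguous destination -- the cudaMemcpy2D case.
gpu_slot.copy_(cpu_cache[:, block_id], non_blocking=True)
torch.cuda.synchronize()
```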
@@ -1,297 +0,0 @@
"""
Test Attention layer with KV cache offload - N-way Pipeline.

This test demonstrates and verifies the N-way pipeline with:
- Per-slot transfer streams for parallel H2D
- Dedicated compute stream (avoids CUDA default stream implicit sync)
- Pre-load phase + main loop with immediate slot reuse

Key difference from previous test:
- We first pre-fill many chunks to CPU cache
- Then simulate processing a new chunk that loads ALL previous blocks
- This exercises the full N-way pipeline with many blocks in flight
"""

import torch
from nanovllm.layers.attention import Attention
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
from nanovllm.engine.sequence import Sequence
from nanovllm.utils.context import set_context, reset_context


# ============================================================
# Configuration
# ============================================================

NUM_LAYERS = 8
NUM_HEADS = 8
NUM_KV_HEADS = 8
HEAD_DIM = 64
BLOCK_SIZE = 1024
CHUNK_SIZE = 1024

NUM_GPU_SLOTS = 6  # N-way pipeline with 6 slots
NUM_CPU_BLOCKS = 16  # Many blocks to load from CPU

DTYPE = torch.bfloat16
DEVICE = "cuda"


# ============================================================
# Setup
# ============================================================

def create_manager():
    manager = HybridKVCacheManager(
        num_gpu_slots=NUM_GPU_SLOTS,
        num_cpu_blocks=NUM_CPU_BLOCKS,
        block_size=BLOCK_SIZE,
    )
    manager.allocate_cache(
        num_layers=NUM_LAYERS,
        num_kv_heads=NUM_KV_HEADS,
        head_dim=HEAD_DIM,
        dtype=DTYPE,
    )
    return manager


def create_attention_layers(manager):
    layers = []
    for layer_id in range(NUM_LAYERS):
        attn = Attention(
            num_heads=NUM_HEADS,
            head_dim=HEAD_DIM,
            scale=HEAD_DIM ** -0.5,
            num_kv_heads=NUM_KV_HEADS,
        )
        attn.layer_id = layer_id
        k_cache, v_cache = manager.get_layer_cache(layer_id)
        attn.k_cache = k_cache
        attn.v_cache = v_cache
        layers.append(attn.to(DEVICE))
    return layers


# ============================================================
# Pre-fill CPU cache with random data
# ============================================================

def prefill_cpu_cache(manager, num_blocks):
    """
    Fill CPU cache with random KV data for num_blocks blocks.
    This simulates having already processed many chunks.
    """
    offload_engine = manager.offload_engine

    for block_id in range(num_blocks):
        # Generate random KV data for all layers
        for layer_id in range(NUM_LAYERS):
            k_data = torch.randn(
                BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM,
                dtype=DTYPE, device=DEVICE
            )
            v_data = torch.randn(
                BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM,
                dtype=DTYPE, device=DEVICE
            )

            # Copy to CPU cache
            offload_engine.k_cache_cpu[layer_id, block_id].copy_(k_data)
            offload_engine.v_cache_cpu[layer_id, block_id].copy_(v_data)

    return list(range(num_blocks))


# ============================================================
# Simulate N-way Pipeline (mirrors attention.py logic)
# ============================================================

def simulate_nway_pipeline(
    layer_id: int,
    q_batched: torch.Tensor,
    cpu_block_table: list,
    load_slots: list,
    offload_engine,
    scale: float,
):
    """
    Simulate N-way pipeline for a single layer.
    This mirrors the logic in Attention._ring_buffer_pipeline_load().
    """
    num_blocks = len(cpu_block_table)
    num_slots = len(load_slots)

    o_acc, lse_acc = None, None

    # Phase 1: Pre-load up to num_slots blocks
    num_preload = min(num_slots, num_blocks)
    torch.cuda.nvtx.range_push(f"Phase1_Preload: L{layer_id}")
    for i in range(num_preload):
        offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
    torch.cuda.nvtx.range_pop()

    # Phase 2: Main loop with compute_stream
    compute_stream = offload_engine.compute_stream

    for block_idx in range(num_blocks):
        torch.cuda.nvtx.range_push(f"Block: L{layer_id} B{block_idx}")

        current_slot = load_slots[block_idx % num_slots]

        # Wait for transfer
        offload_engine.wait_slot_layer(current_slot, layer_id)

        # Compute on dedicated stream
        with torch.cuda.stream(compute_stream):
            torch.cuda.nvtx.range_push(f"FlashAttn: L{layer_id} B{block_idx}")
            prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, layer_id)
            prev_o, prev_lse = flash_attn_with_lse(
                q_batched, prev_k, prev_v,
                softmax_scale=scale,
                causal=False,
            )
            torch.cuda.nvtx.range_pop()
            offload_engine.record_slot_compute_done(current_slot, layer_id)

        # Start next transfer (reuse current_slot)
        next_block_idx = block_idx + num_slots
        if next_block_idx < num_blocks:
            offload_engine.load_to_slot_layer(
                current_slot, layer_id, cpu_block_table[next_block_idx]
            )

        # Merge
        with torch.cuda.stream(compute_stream):
            if o_acc is None:
                o_acc, lse_acc = prev_o, prev_lse
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

        torch.cuda.nvtx.range_pop()

    return o_acc, lse_acc


def simulate_full_forward(layers, manager, cpu_block_table, chunk_size):
    """
    Simulate forward pass through all layers, loading previous blocks from CPU.
    This is the key test: many blocks loaded via N-way pipeline.
    """
    offload_engine = manager.offload_engine

    # Current chunk index (we're processing the "next" chunk after all prefilled ones)
    current_chunk_idx = len(cpu_block_table)
    write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
    load_slots = offload_engine.get_load_slots_for_prefill(write_slot)

    # Random query for attention
    q = torch.randn(1, chunk_size, NUM_HEADS, HEAD_DIM, dtype=DTYPE, device=DEVICE)

    outputs = []
    for layer in layers:
        torch.cuda.nvtx.range_push(f"Layer: {layer.layer_id}")

        o_acc, lse_acc = simulate_nway_pipeline(
            layer.layer_id,
            q,
            cpu_block_table,
            load_slots,
            offload_engine,
            layer.scale,
        )

        outputs.append(o_acc)
        torch.cuda.nvtx.range_pop()

    return outputs


# ============================================================
# Main Test
# ============================================================

print("=" * 60)
print("Test: N-way Pipeline with CPU Offload")
print("=" * 60)

# 1. Setup
print("\n[1] Creating manager and attention layers...")
manager = create_manager()
layers = create_attention_layers(manager)
offload_engine = manager.offload_engine

print(f" - GPU slots: {NUM_GPU_SLOTS}")
print(f" - CPU blocks: {NUM_CPU_BLOCKS}")
print(f" - Per-slot streams: {len(offload_engine.slot_transfer_streams)}")
print(f" - Compute stream: {offload_engine.compute_stream}")

# 2. Pre-fill CPU cache
NUM_PREV_BLOCKS = 12  # Many blocks to load via N-way pipeline
print(f"\n[2] Pre-filling {NUM_PREV_BLOCKS} blocks to CPU cache...")
cpu_block_table = prefill_cpu_cache(manager, NUM_PREV_BLOCKS)
print(f" - CPU blocks filled: {cpu_block_table}")

# 3. Verify pipeline configuration
current_chunk_idx = NUM_PREV_BLOCKS
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
print(f"\n[3] Pipeline configuration for chunk {current_chunk_idx}:")
print(f" - Write slot: {write_slot}")
print(f" - Load slots: {load_slots}")
print(f" - Pipeline depth (N-way): {len(load_slots)}")
assert len(load_slots) == NUM_GPU_SLOTS - 1, f"Expected {NUM_GPU_SLOTS - 1} load slots"

# 4. Warmup
print("\n[4] Warmup (3 iterations)...")
for i in range(3):
    outputs = simulate_full_forward(layers, manager, cpu_block_table, CHUNK_SIZE)
    torch.cuda.synchronize()
    print(f" - Warmup {i+1}/3 done")

# 5. Benchmark
NUM_ITERS = 10
print(f"\n[5] Benchmark ({NUM_ITERS} iterations)...")

torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
for i in range(NUM_ITERS):
    torch.cuda.nvtx.range_push(f"Iteration_{i}")
    outputs = simulate_full_forward(layers, manager, cpu_block_table, CHUNK_SIZE)
    torch.cuda.nvtx.range_pop()
end_event.record()

torch.cuda.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)

# Stats
total_blocks_loaded = NUM_PREV_BLOCKS * NUM_LAYERS * NUM_ITERS
blocks_per_sec = total_blocks_loaded / (elapsed_ms / 1000)
total_tokens = NUM_PREV_BLOCKS * BLOCK_SIZE * NUM_LAYERS * NUM_ITERS
tokens_per_sec = total_tokens / (elapsed_ms / 1000)

print(f"\n[6] Results:")
print(f" - Total time: {elapsed_ms:.2f} ms")
print(f" - Per iteration: {elapsed_ms / NUM_ITERS:.2f} ms")
print(f" - Blocks loaded: {total_blocks_loaded} ({blocks_per_sec:.0f} blocks/s)")
print(f" - Tokens processed: {total_tokens} ({tokens_per_sec:.0f} tok/s)")

# 7. Verification
print("\n[7] Verification:")
assert len(outputs) == NUM_LAYERS, f"Expected {NUM_LAYERS} outputs"
for i, o in enumerate(outputs):
    assert o is not None, f"Layer {i} output is None"
    assert o.shape == (1, CHUNK_SIZE, NUM_HEADS, HEAD_DIM), f"Layer {i} shape mismatch"
print(" - All layer outputs valid ✓")
print(" - N-way pipeline executed correctly ✓")

# Cleanup
reset_context()

print("\n" + "=" * 60)
print("test_attention_offload: PASSED")
print("=" * 60)
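The slot-reuse order implemented by `simulate_nway_pipeline` above can be illustrated without a GPU. A pure-Python sketch (slot numbers here are illustrative, not the engine's actual assignment):

```python
# Illustrates the ring-buffer schedule: pre-load fills each load slot once,
# then every finished block immediately re-targets its slot at block + num_slots.
num_blocks, load_slots = 12, [1, 2, 3, 4, 5]  # e.g. slot 0 reserved as write slot
num_slots = len(load_slots)

schedule = [(b, load_slots[b]) for b in range(min(num_slots, num_blocks))]  # pre-load phase
for b in range(num_blocks):                    # main loop
    nxt = b + num_slots
    if nxt < num_blocks:
        schedule.append((nxt, load_slots[b % num_slots]))

for block, slot in schedule:
    print(f"load block {block:2d} -> slot {slot}")
```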
@@ -1,169 +0,0 @@
"""
Test script for chunked attention correctness.

Validates that chunked prefill using flash_attn_with_lse + merge_attention_outputs
produces the same result as full flash_attn_varlen_func.

Scenario: Simulating chunked prefill where we process query chunk by chunk.
For each query chunk i:
- KV contains all tokens from chunk 0 to chunk i
- Previous KV chunks (0 to i-1): full attention (no causal mask)
- Current KV chunk (i): causal attention (diagonal block)
"""

import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_func
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

# ============================================================
# Utility Functions
# ============================================================

def compute_chunked_prefill_for_chunk(
    q_chunk: torch.Tensor,
    kv_chunks: list,
    current_chunk_idx: int,
) -> torch.Tensor:
    """
    Compute attention for a single query chunk against all KV chunks up to current.

    This simulates chunked prefill for query chunk `current_chunk_idx`:
    - KV chunks 0 to current_chunk_idx-1: full attention (all previous tokens visible)
    - KV chunk current_chunk_idx: causal attention (diagonal block)

    Args:
        q_chunk: [batch, chunk_size, nheads, headdim] - current query chunk
        kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim]
        current_chunk_idx: Index of the current chunk being processed

    Returns:
        out: [batch, chunk_size, nheads, headdim]
    """
    accumulated_o = None
    accumulated_lse = None

    for i in range(current_chunk_idx + 1):
        k_chunk, v_chunk = kv_chunks[i]

        # Previous chunks: no causal mask (all tokens visible)
        # Current chunk (diagonal): causal mask
        is_diagonal = (i == current_chunk_idx)

        chunk_o, chunk_lse = flash_attn_with_lse(
            q_chunk, k_chunk, v_chunk, causal=is_diagonal
        )

        if accumulated_o is None:
            accumulated_o = chunk_o
            accumulated_lse = chunk_lse
        else:
            accumulated_o, accumulated_lse = merge_attention_outputs(
                accumulated_o, accumulated_lse,
                chunk_o, chunk_lse
            )

    return accumulated_o


def compute_reference_causal(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
) -> torch.Tensor:
    """
    Compute reference causal attention using flash_attn_func.

    Args:
        q, k, v: [batch, seqlen, nheads, headdim]

    Returns:
        out: [batch, seqlen, nheads, headdim]
    """
    return flash_attn_func(q, k, v, causal=True)


# ============================================================
# Main Test Script
# ============================================================

torch.manual_seed(42)

# Test configurations: (batch, num_chunks, chunk_size, nheads, headdim)
TEST_CASES = [
    (1, 4, 256, 8, 128),
    (1, 4, 512, 8, 128),
    (1, 8, 512, 8, 128),
    (1, 32, 1024, 8, 128),
    (1, 32, 1024, 32, 128),  # More heads
    (1, 32, 256, 8, 64),  # Smaller head dim
]

DTYPES = [torch.float16, torch.bfloat16]

print("=" * 80)
print("Test: Chunked Prefill Attention vs Reference (flash_attn_func causal)")
print("=" * 80)
print("Simulating chunked prefill: Q chunk attends to all KV chunks up to current")
print(" - Previous KV chunks: full attention (no causal mask)")
print(" - Current KV chunk (diagonal): causal attention")
print()

all_passed = True

for dtype in DTYPES:
    print(f"--- dtype: {dtype} ---")

    for batch, num_chunks, chunk_size, nheads, headdim in TEST_CASES:
        seqlen = num_chunks * chunk_size

        # Generate full Q, K, V
        q_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
        k_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
        v_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)

        # Reference: full causal attention
        out_ref = compute_reference_causal(q_full, k_full, v_full)

        # Split into chunks
        q_chunks = [q_full[:, i*chunk_size:(i+1)*chunk_size] for i in range(num_chunks)]
        kv_chunks = [
            (k_full[:, i*chunk_size:(i+1)*chunk_size],
             v_full[:, i*chunk_size:(i+1)*chunk_size])
            for i in range(num_chunks)
        ]

        # Compute chunked prefill for each query chunk
        out_chunks = []
        for chunk_idx in range(num_chunks):
            chunk_out = compute_chunked_prefill_for_chunk(
                q_chunks[chunk_idx],
                kv_chunks,
                chunk_idx,
            )
            out_chunks.append(chunk_out)

        # Concatenate chunked outputs
        out_chunked = torch.cat(out_chunks, dim=1)

        # Compare
        diff = (out_ref - out_chunked).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()

        # Tolerance: fp16/bf16 have limited precision
        tol = 1e-2
        passed = max_diff < tol
        all_passed = all_passed and passed

        status = "PASS" if passed else "FAIL"
        print(
            f"[{status}] seqlen={seqlen:5d} chunks={num_chunks} "
            f"chunk_size={chunk_size:4d} heads={nheads:2d} dim={headdim:3d} "
            f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}"
        )

    print()

print("=" * 80)
assert all_passed, "Some tests failed!"
print("test_chunked_attention: PASSED")
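The correctness of the chunk-by-chunk accumulation above rests on the log-sum-exp merge performed by `merge_attention_outputs`. A reference sketch of that merge follows; the tensor layouts are assumptions (`o` as `[batch, seq, heads, dim]`, `lse` as `[batch, heads, seq]`, matching flash-attn's `softmax_lse`), and the repo's actual helper may differ in layout or in-place behavior.

```python
import torch

def merge_two_chunks(o1, lse1, o2, lse2):
    # Combined normalizer over the union of the two key sets.
    lse = torch.logaddexp(lse1, lse2)
    # Each partial output is reweighted by its share of the total softmax mass.
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # -> [batch, seq, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse
```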
@@ -1,391 +0,0 @@
"""
Correctness test for chunked decode attention.

Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""

import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"

import torch
from random import randint, seed
from typing import Dict, List
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_func

# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
NUM_DECODE_TOKENS = 5
BLOCK_SIZE = 1024

# State
prefill_captures: List[Dict] = []
decode_captures: List[Dict] = []


def make_ones_injection_hook():
    """Inject Q=K=V=1.0 for deterministic testing."""
    def hook(module, inputs):
        q, k, v = inputs[0], inputs[1], inputs[2]
        q_ones = torch.ones_like(q)
        k_ones = torch.ones_like(k)
        v_ones = torch.ones_like(v)
        return (q_ones, k_ones, v_ones) + inputs[3:]
    return hook


def make_capture_hook(layer_id: int):
    """Capture Q, K, V, output during inference."""
    def hook(module, inputs, output):
        ctx = get_context()
        q, k, v = inputs

        if ctx.is_prefill:
            chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
            prefill_captures.append({
                'layer_id': layer_id,
                'chunk_idx': chunk_idx,
                'q': q.clone().cpu(),
                'k': k.clone().cpu(),
                'v': v.clone().cpu(),
                'output': output.clone().cpu(),
            })
        else:
            decode_step = len([c for c in decode_captures if c['layer_id'] == layer_id])
            decode_captures.append({
                'layer_id': layer_id,
                'decode_step': decode_step,
                'q': q.clone().cpu(),
                'k': k.clone().cpu(),
                'v': v.clone().cpu(),
                'output': output.clone().cpu(),
            })
    return hook


def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
                             k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
                             block_size: int, num_prefill_chunks: int) -> torch.Tensor:
    """
    Compute reference decode output using CPU KV cache and standard flash attention.

    For decode, query attends to:
    1. All prefill KV (from CPU cache)
    2. All previous decode tokens (from captured decode k, v)
    """
    # Get decode capture for this layer and step
    decode_cap = None
    for c in decode_captures:
        if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
            decode_cap = c
            break

    if decode_cap is None:
        return None

    # Query: single decode token
    q = decode_cap['q'].cuda()  # [1, num_heads, head_dim]
    q_batched = q.unsqueeze(0)  # [1, 1, num_heads, head_dim]

    # Collect all K, V: prefill chunks from captures + decode tokens from captures
    # NOTE: We use prefill captures directly instead of CPU cache because
    # the CPU block ID may not equal the chunk index.
    all_k = []
    all_v = []

    # 1. Prefill chunks from captures (use captured K/V, not CPU cache)
    for cidx in range(num_prefill_chunks):
        prefill_cap = None
        for c in prefill_captures:
            if c['layer_id'] == layer_id and c['chunk_idx'] == cidx:
                prefill_cap = c
                break

        if prefill_cap is not None:
            # Use captured K/V directly (guaranteed to be correct layer data)
            all_k.append(prefill_cap['k'].cuda())
            all_v.append(prefill_cap['v'].cuda())

    # 2. Decode tokens from captures (up to and including current step)
    for step in range(decode_step + 1):
        for c in decode_captures:
            if c['layer_id'] == layer_id and c['decode_step'] == step:
                all_k.append(c['k'].cuda())
                all_v.append(c['v'].cuda())
                break

    if not all_k:
        return None

    # Concatenate all K, V
    full_k = torch.cat(all_k, dim=0).unsqueeze(0)  # [1, total_len, kv_heads, head_dim]
    full_v = torch.cat(all_v, dim=0).unsqueeze(0)

    # Run flash attention (non-causal since we explicitly control what KV to include)
    output = flash_attn_func(
        q_batched, full_k, full_v,
        softmax_scale=scale,
        causal=False,
    )

    return output.squeeze(0).squeeze(0).cpu()  # [num_heads, head_dim]


# ============================================================
# Main
# ============================================================

llm = LLM(
    MODEL_PATH,
    enforce_eager=True,
    max_model_len=MAX_MODEL_LEN,
    max_num_batched_tokens=MAX_MODEL_LEN,
    enable_cpu_offload=True,
    kvcache_block_size=BLOCK_SIZE,
    num_gpu_blocks=NUM_GPU_BLOCKS,
    dtype="float16",
)

# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5

# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
    # Pre-hook: inject all ones for Q, K, V
    # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
    # hooks.append(pre_hook)
    # Post-hook: capture Q, K, V, output
    post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
    hooks.append(post_hook)

# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS), use_tqdm=False)

# Remove hooks
for hook in hooks:
    hook.remove()

# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()

# Calculate number of prefill chunks
num_prefill_chunks = INPUT_LEN // BLOCK_SIZE

# Debug: Compare decode_buffer with captured K/V
print("\n=== DEBUG: Comparing decode_buffer with captured K/V ===")
decode_k_buffer = offload_engine.decode_k_buffer.clone().cpu()
for step in range(NUM_DECODE_TOKENS):
    for layer_id in [0, 17, 35]:  # Sample a few layers
        # Find captured K for this step and layer
        for c in decode_captures:
            if c['layer_id'] == layer_id and c['decode_step'] == step:
                captured_k = c['k'].squeeze(0)  # [kv_heads, head_dim]
                buffer_k = decode_k_buffer[layer_id, step]  # [kv_heads, head_dim]
                diff = (captured_k - buffer_k).abs().max().item()
                print(f"Step {step}, Layer {layer_id}: captured vs buffer max_diff={diff:.6f}")
                break

# Debug: Verify that decode_buffer slices match concatenated captures
print("\n=== DEBUG: Verifying decode_buffer slices ===")
for layer_id in [0]:
    for decode_step in [1, 2]:  # Check steps that use multiple tokens
        # Build expected slice from captures
        expected_k_list = []
        for step in range(decode_step + 1):
            for c in decode_captures:
                if c['layer_id'] == layer_id and c['decode_step'] == step:
                    expected_k_list.append(c['k'].squeeze(0))  # [kv_heads, head_dim]
                    break
        if expected_k_list:
            expected_k = torch.stack(expected_k_list, dim=0)  # [num_tokens, kv_heads, head_dim]
            buffer_slice = decode_k_buffer[layer_id, 0:decode_step+1]
            diff = (expected_k - buffer_slice).abs().max().item()
            print(f"Decode step {decode_step}, Layer {layer_id}: buffer slice vs expected max_diff={diff:.6f}")
            # Print first values
            print(f" Buffer[0,0,0]={buffer_slice[0,0,0].item():.6f}, Expected[0,0,0]={expected_k[0,0,0].item():.6f}")
            if decode_step >= 1:
                print(f" Buffer[1,0,0]={buffer_slice[1,0,0].item():.6f}, Expected[1,0,0]={expected_k[1,0,0].item():.6f}")

# Debug: Print expected K value for block 0, layer 0 (to compare with actual loading)
print("\n=== DEBUG: Expected K values for block 0, layer 0 ===")
for c in prefill_captures:
    if c['layer_id'] == 0 and c['chunk_idx'] == 0:
        print(f"Captured K[0,0,0] for layer 0, chunk 0: {c['k'][0,0,0].item():.6f}")
        break
print(f"CPU cache K[0,0,0,0,0] for layer 0, block 0: {k_cache_cpu[0,0,0,0,0].item():.6f}")

# Debug: Compare CPU cache with captured prefill K/V
print("\n=== DEBUG: Comparing CPU cache with captured prefill K/V ===")
for chunk_idx in [0, 7, 15]:  # Sample a few chunks
    for layer_id in [0, 17, 35]:  # Sample a few layers
        # Find captured K for this chunk and layer
        for c in prefill_captures:
            if c['layer_id'] == layer_id and c['chunk_idx'] == chunk_idx:
                captured_k = c['k']  # [seq_len, kv_heads, head_dim]
                cpu_cache_k = k_cache_cpu[layer_id, chunk_idx, :captured_k.shape[0]]
                diff = (captured_k - cpu_cache_k).abs().max().item()
                print(f"Chunk {chunk_idx}, Layer {layer_id}: captured vs CPU cache max_diff={diff:.6f}")
                break

# Debug: Get cpu_block_table to check order
kvcache_manager = llm.model_runner.kvcache_manager
# Find the sequence (it should still exist)
from nanovllm.engine.sequence import Sequence
for attr_name in ['sequences', '_sequences', 'active_sequences']:
    if hasattr(kvcache_manager, attr_name):
        print(f"Found {attr_name}")
        break

# Try to get cpu_block_table through a different way
print(f"\n=== DEBUG: CPU block order ===")
# For each prefill capture, check which CPU block it ended up in
for chunk_idx in range(num_prefill_chunks):
    for c in prefill_captures:
        if c['layer_id'] == 0 and c['chunk_idx'] == chunk_idx:
            # Check if this chunk's K matches any CPU block
            captured_k_first = c['k'][0, 0, 0].item()
            for block_id in range(num_prefill_chunks):
                cpu_k_first = k_cache_cpu[0, block_id, 0, 0, 0].item()
                if abs(captured_k_first - cpu_k_first) < 1e-6:
                    print(f"Chunk {chunk_idx} -> CPU block {block_id}")
                    break
            break

# Debug: Check reference vs actual for decode steps 0 and 1
# Also compute partial references (prefill only, decode only) to isolate the bug
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
for decode_step in [0, 1]:
    print(f"\n=== DEBUG: Reference vs Actual for layer 0, decode {decode_step} ===")
    layer_id = 0
    # Find the capture
    for c in decode_captures:
        if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
            q = c['q'].cuda()  # [1, num_heads, head_dim]
            q_batched = q.unsqueeze(0)  # [1, 1, num_heads, head_dim]

            # Build prefill K/V per-block for block-by-block reference
            prefill_k_blocks = []
            prefill_v_blocks = []
            for cidx in range(num_prefill_chunks):
                for pc in prefill_captures:
                    if pc['layer_id'] == layer_id and pc['chunk_idx'] == cidx:
                        prefill_k_blocks.append(pc['k'].cuda().unsqueeze(0))  # [1, block_size, kv_heads, head_dim]
                        prefill_v_blocks.append(pc['v'].cuda().unsqueeze(0))
                        break

            # Build decode K/V
            decode_k_list = []
            decode_v_list = []
            for step in range(decode_step + 1):
                for dc in decode_captures:
                    if dc['layer_id'] == layer_id and dc['decode_step'] == step:
                        decode_k_list.append(dc['k'].cuda())
                        decode_v_list.append(dc['v'].cuda())
                        break

            full_prefill_k = torch.cat([kb.squeeze(0) for kb in prefill_k_blocks], dim=0).unsqueeze(0)
            full_prefill_v = torch.cat([vb.squeeze(0) for vb in prefill_v_blocks], dim=0).unsqueeze(0)
            full_decode_k = torch.cat(decode_k_list, dim=0).unsqueeze(0)
            full_decode_v = torch.cat(decode_v_list, dim=0).unsqueeze(0)

            full_k = torch.cat([full_prefill_k, full_decode_k], dim=1)
            full_v = torch.cat([full_prefill_v, full_decode_v], dim=1)

            print(f"Q shape: {q_batched.shape}")
            print(f"Prefill K shape: {full_prefill_k.shape}")
            print(f"Decode K shape: {full_decode_k.shape}")
            print(f"Full K shape: {full_k.shape}")
            print(f"Total tokens: prefill={num_prefill_chunks * BLOCK_SIZE}, decode={decode_step + 1}")

            # Reference output (single attention over all)
            ref_output = flash_attn_func(
                q_batched, full_k, full_v,
                softmax_scale=scale,
                causal=False,
            )

            # Chunked reference: prefill attention + decode attention + merge
            prefill_o, prefill_lse = flash_attn_with_lse(
                q_batched, full_prefill_k, full_prefill_v,
                softmax_scale=scale,
                causal=False,
            )
            decode_o, decode_lse = flash_attn_with_lse(
                q_batched, full_decode_k, full_decode_v,
                softmax_scale=scale,
                causal=False,
            )
            chunked_output, _ = merge_attention_outputs(prefill_o, prefill_lse, decode_o, decode_lse)

            # Block-by-block reference (simulating ring buffer pipeline)
            block_o_acc, block_lse_acc = None, None
            for bidx, (kb, vb) in enumerate(zip(prefill_k_blocks, prefill_v_blocks)):
                o_blk, lse_blk = flash_attn_with_lse(q_batched, kb, vb, softmax_scale=scale, causal=False)
                if block_o_acc is None:
                    block_o_acc, block_lse_acc = o_blk, lse_blk
                else:
                    block_o_acc, block_lse_acc = merge_attention_outputs(block_o_acc, block_lse_acc, o_blk, lse_blk)

            # Compare block-by-block vs single
            block_vs_single_diff = (block_o_acc - prefill_o).abs().max().item()
            print(f"Block-by-block vs Single max_diff: {block_vs_single_diff:.6f}")

            # Compare full reference vs chunked reference
            ref_vs_chunked_diff = (ref_output - chunked_output).abs().max().item()
            print(f"Reference vs Chunked-reference max_diff: {ref_vs_chunked_diff:.6f}")

            ref_output = ref_output.squeeze(0).squeeze(0).cpu()
            chunked_output_cpu = chunked_output.squeeze(0).squeeze(0).cpu()

            # Actual output
            actual_output = c['output'].squeeze(0)
            if actual_output.dim() == 3:
                actual_output = actual_output.squeeze(0)

            diff_ref = (actual_output - ref_output).abs()
            diff_chunked = (actual_output - chunked_output_cpu).abs()
            print(f"Actual vs Reference max_diff: {diff_ref.max().item():.6f}")
            print(f"Actual vs Chunked-ref max_diff: {diff_chunked.max().item():.6f}")
            break
    print()

# Verify decode outputs
all_passed = True

for c in decode_captures:
    layer_id = c['layer_id']
    decode_step = c['decode_step']

    ref_output = compute_decode_reference(
        layer_id, decode_step, scale,
        k_cache_cpu, v_cache_cpu, BLOCK_SIZE, num_prefill_chunks
    )
    if ref_output is None:
        continue

    actual_output = c['output'].squeeze(0)
    if actual_output.dim() == 3:
        actual_output = actual_output.squeeze(0)

    diff = (actual_output - ref_output).abs()
    max_diff = diff.max().item()

    passed = max_diff < 1e-1
    all_passed = all_passed and passed

    if not passed:
        print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")

print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}")

@@ -1,196 +0,0 @@
"""
Correctness test for chunked prefill attention.

Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""

import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"

import torch
from random import randint, seed
from typing import Dict, List
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_varlen_func

# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
BLOCK_SIZE = 1024

# State - capture Q and output for each (layer, chunk)
captures: List[Dict] = []


def make_ones_injection_hook():
    """Inject Q=K=V=1.0 for deterministic testing."""
    def hook(module, inputs):
        ctx = get_context()
        if not ctx.is_prefill:
            return inputs

        q, k, v = inputs[0], inputs[1], inputs[2]
        q_ones = torch.ones_like(q)
        k_ones = torch.ones_like(k)
        v_ones = torch.ones_like(v)
        return (q_ones, k_ones, v_ones) + inputs[3:]
    return hook


def make_capture_hook(layer_id: int):
    """Capture Q and output during prefill."""
    def hook(module, inputs, output):
        ctx = get_context()
        if not ctx.is_prefill:
            return

        q, k, v = inputs
        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0

        captures.append({
            'layer_id': layer_id,
            'chunk_idx': chunk_idx,
            'q': q.clone().cpu(),
            'k': k.clone().cpu(),
            'v': v.clone().cpu(),
            'output': output.clone().cpu(),
        })
    return hook


def compute_reference(layer_id: int, chunk_idx: int, scale: float,
                      k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
                      block_size: int) -> torch.Tensor:
    """
    Compute reference output using CPU KV cache and standard flash attention.

    Concatenates all Q, K, V from chunks 0..chunk_idx and runs causal attention,
    then extracts output for the current chunk.
    """
    # Get all captures for this layer up to chunk_idx
    layer_captures = [c for c in captures
                      if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
    layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx'])

    if not layer_captures:
        return None

    # Collect Q from captures, K/V from CPU cache
    all_q = []
    all_k = []
    all_v = []
    chunk_lengths = []

    for c in layer_captures:
        cidx = c['chunk_idx']
        q = c['q'].cuda()  # [seqlen, nheads, headdim]
        all_q.append(q)
        chunk_lengths.append(q.shape[0])

        # Get K, V from CPU cache (already offloaded during prefill)
        # CPU cache shape: [num_layers, num_blocks, block_size, kv_heads, head_dim]
        k = k_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
        v = v_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
        all_k.append(k)
        all_v.append(v)

    # Concatenate
    full_q = torch.cat(all_q, dim=0)
    full_k = torch.cat(all_k, dim=0)
    full_v = torch.cat(all_v, dim=0)
    total_len = full_q.shape[0]

    # Run standard causal flash attention
    cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda')
    full_o = flash_attn_varlen_func(
        full_q, full_k, full_v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=total_len,
        max_seqlen_k=total_len,
        softmax_scale=scale,
        causal=True,
    )

    # Extract output for current chunk
    start_pos = sum(chunk_lengths[:-1])
    end_pos = sum(chunk_lengths)
    return full_o[start_pos:end_pos].cpu()


# ============================================================
# Main
# ============================================================

llm = LLM(
    MODEL_PATH,
    enforce_eager=True,
    max_model_len=MAX_MODEL_LEN,
    max_num_batched_tokens=MAX_MODEL_LEN,
    enable_cpu_offload=True,
    kvcache_block_size=BLOCK_SIZE,
    num_gpu_blocks=NUM_GPU_BLOCKS,
    dtype="float16",
)

# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5

# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
    # Pre-hook: inject all ones for Q, K, V
    # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
    # hooks.append(pre_hook)
    # Post-hook: capture Q, K, V, output
    post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
    hooks.append(post_hook)

# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=1), use_tqdm=False)

# Remove hooks
for hook in hooks:
    hook.remove()

# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()

# Verify: compare actual output with reference computed from CPU cache
all_passed = True
num_chunks = INPUT_LEN // BLOCK_SIZE

for idx, c in enumerate(captures):
    layer_id = c['layer_id']
    chunk_idx = c['chunk_idx']

    # Skip chunk 0 (no previous KV to load)
    if chunk_idx == 0:
        continue

    ref_output = compute_reference(layer_id, chunk_idx, scale, k_cache_cpu, v_cache_cpu, BLOCK_SIZE)
    if ref_output is None:
        continue

    actual_output = c['output']
    diff = (actual_output - ref_output).abs()
    max_diff = diff.max().item()

    passed = max_diff < 1e-1  # float16 tolerance
    all_passed = all_passed and passed

    if not passed:
        print(f"[FAIL] Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
        __import__('pdb').set_trace()

print(f"test_chunked_prefill_hook: {'PASSED' if all_passed else 'FAILED'}")

@@ -1,137 +0,0 @@
"""
Test KV cache offload correctness using debug hooks.
Injects distinctive K/V values, verifies loaded tensors match expected patterns.
"""

import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"

import inspect
from random import randint, seed
from typing import Dict, List
import torch
from torch import Tensor
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context

# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 32 * 1024
NUM_GPU_BLOCKS = 4
INPUT_LEN = 32 * 1024
BLOCK_SIZE = 1024

# State
load_log: List[Dict] = []
current_chunk: List[int] = [0]


def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None:
    """Record loaded tensor values for layer 0."""
    if layer_id != 0:
        return

    # Go up the stack to find kvcache_manager and print k_cache_gpu[*][0,0,0] for all slots
    frame = inspect.currentframe()
    try:
        caller_frame = frame.f_back
        if caller_frame is not None:
            local_vars = caller_frame.f_locals
            if 'self' in local_vars:
                self_obj = local_vars['self']
                if hasattr(self_obj, 'k_cache_gpu'):
                    num_slots = self_obj.k_cache_gpu.shape[0]
                    vals = []
                    for i in range(num_slots):
                        v = self_obj.k_cache_gpu[i][0,0,0].item()
                        if i == slot_idx:
                            vals.append(f"[{v}]")
                        else:
                            vals.append(str(v))
                    print(f"[DEBUG] k_cache_gpu[0..{num_slots-1}][0,0,0] = [{', '.join(vals)}]")
    finally:
        del frame

    load_log.append({
        "chunk_idx": current_chunk[0],
        "cpu_block_id": cpu_block_id,
        "k_value": k.float().mean().item(),
    })


def make_pattern_injection_hook(layer_id):
    """Inject K = chunk_idx + 1, V = -(chunk_idx + 1) for layer 0."""
    def hook(module, inputs):
        ctx = get_context()
        if not ctx.is_prefill or layer_id != 0:
            return inputs
        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        current_chunk[0] = chunk_idx
        if len(inputs) >= 3:
            q, k, v = inputs[0], inputs[1], inputs[2]
            k_new = torch.full_like(k, float(chunk_idx + 1))
            v_new = torch.full_like(v, float(-(chunk_idx + 1)))
            return (q, k_new, v_new) + inputs[3:]
        return inputs
    return hook


def verify() -> bool:
    """Verify blocks loaded in correct order with correct K values."""
    chunk_loads: Dict[int, List[tuple]] = {}
    for log in load_log:
        chunk = log["chunk_idx"]
        if chunk not in chunk_loads:
            chunk_loads[chunk] = []
        chunk_loads[chunk].append((log["cpu_block_id"], log["k_value"]))

    for chunk, loads in chunk_loads.items():
        expected_blocks = list(range(chunk))
        actual_blocks = [b for b, _ in loads]
        k_values = [k for _, k in loads]
        expected_k = [float(b + 1) for b in expected_blocks]

        if actual_blocks != expected_blocks:
            return False
        if not all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k)):
            return False
    return True


# Main
llm = LLM(
    MODEL_PATH,
    enforce_eager=True,
    max_model_len=MAX_MODEL_LEN,
    max_num_batched_tokens=MAX_MODEL_LEN,
    enable_cpu_offload=True,
    kvcache_block_size=BLOCK_SIZE,
    num_gpu_blocks=NUM_GPU_BLOCKS,
    dtype="float16",
)

offload_engine = llm.model_runner.kvcache_manager.offload_engine
offload_engine.enable_debug_mode()
offload_engine.register_debug_hook(debug_load_hook)

hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
    hooks.append(decoder_layer.self_attn.attn.register_forward_pre_hook(
        make_pattern_injection_hook(layer_idx)
    ))

seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1), use_tqdm=False)

for hook in hooks:
    hook.remove()
offload_engine.remove_debug_hook(debug_load_hook)
offload_engine.disable_debug_mode()

# Verify
num_chunks = INPUT_LEN // BLOCK_SIZE
expected_loads = num_chunks * (num_chunks - 1) // 2
passed = len(load_log) == expected_loads and verify()

print(f"test_debug_verification: {'PASSED' if passed else 'FAILED'}")

@@ -1,276 +0,0 @@
"""
Test script for flash_attn_with_kvcache based chunked prefill.

Verifies that chunked prefill produces identical results to full attention.
"""

import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache


def chunk_prefill(q_full, k_full, v_full, k_cache, v_cache, cache_seqlens, chunk_size):
    """
    Chunked prefill using flash_attn_with_kvcache.

    Args:
        q_full, k_full, v_full: [batch, total_seq_len, heads, head_dim]
        k_cache, v_cache: [batch, max_seq_len, kv_heads, head_dim]
        cache_seqlens: [batch] - current cache lengths
        chunk_size: size of each chunk

    Returns:
        output: [batch, total_seq_len, heads, head_dim]
    """
    total_len = q_full.shape[1]
    outputs = []

    for start in range(0, total_len, chunk_size):
        end = min(start + chunk_size, total_len)

        q_chunk = q_full[:, start:end]
        k_chunk = k_full[:, start:end]
        v_chunk = v_full[:, start:end]

        out = flash_attn_with_kvcache(
            q_chunk,
            k_cache,
            v_cache,
            k=k_chunk,
            v=v_chunk,
            cache_seqlens=cache_seqlens,
            causal=True,
        )
        outputs.append(out)

        cache_seqlens += (end - start)

    return torch.cat(outputs, dim=1)


def reference_attention(q, k, v):
    """Standard flash attention as reference."""
    return flash_attn_func(q, k, v, causal=True)


def test_chunked_prefill_correctness():
    """Test that chunked prefill matches full attention."""

    batch_size = 1
    num_heads = 32
    num_kv_heads = 8  # GQA
    head_dim = 128
    max_seq_len = 131072  # 128K

    test_configs = [
        (1024, 256),     # 1K tokens, 256 chunk
        (2048, 512),     # 2K tokens, 512 chunk
        (4096, 1024),    # 4K tokens, 1K chunk
        (4096, 2048),    # 4K tokens, 2K chunk (2 chunks)
        (8192, 2048),    # 8K tokens, 2K chunk (4 chunks)
        (16384, 4096),   # 16K tokens, 4K chunk
        (32768, 4096),   # 32K tokens, 4K chunk
        (65536, 8192),   # 64K tokens, 8K chunk
        (131072, 8192),  # 128K tokens, 8K chunk (16 chunks)
    ]

    for seq_len, chunk_size in test_configs:
        print(f"\nTesting seq_len={seq_len}, chunk_size={chunk_size}...")

        # Generate random input
        torch.manual_seed(42)
        q = torch.randn(batch_size, seq_len, num_heads, head_dim,
                        dtype=torch.float16, device='cuda')
        k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
                        dtype=torch.float16, device='cuda')
        v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
                        dtype=torch.float16, device='cuda')

        # Expand K/V for non-GQA reference
        k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
        v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)

        # Reference: full attention
        ref_out = reference_attention(q, k_expanded, v_expanded)

        # Chunked prefill with KV cache
        k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
                              dtype=torch.float16, device='cuda')
        v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
                              dtype=torch.float16, device='cuda')
        cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')

        chunked_out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)

        # Compare
        max_diff = (ref_out - chunked_out).abs().max().item()
        mean_diff = (ref_out - chunked_out).abs().mean().item()

        # Verify cache was filled correctly
        assert cache_seqlens[0].item() == seq_len, f"Cache seqlen mismatch: {cache_seqlens[0].item()} != {seq_len}"

        # Check K/V cache content
        k_cache_diff = (k_cache[:, :seq_len] - k).abs().max().item()
        v_cache_diff = (v_cache[:, :seq_len] - v).abs().max().item()

        print(f" Output max_diff: {max_diff:.6f}, mean_diff: {mean_diff:.6f}")
        print(f" KV cache diff: k={k_cache_diff:.6f}, v={v_cache_diff:.6f}")

        # Tolerance for fp16
        tolerance = 1e-2
        if max_diff < tolerance:
            print(f" PASSED")
        else:
            print(f" FAILED (max_diff {max_diff:.6f} >= {tolerance})")
            return False

    return True


def test_incremental_decode():
    """Test that decode after chunked prefill works correctly."""

    batch_size = 1
    num_heads = 32
    num_kv_heads = 8
    head_dim = 128
    max_seq_len = 8192

    prefill_len = 2048
    chunk_size = 512
    num_decode_steps = 10

    print(f"\nTesting incremental decode after chunked prefill...")
    print(f" Prefill: {prefill_len} tokens, chunk_size={chunk_size}")
    print(f" Decode: {num_decode_steps} steps")

    torch.manual_seed(42)

    # Prefill phase
    q_prefill = torch.randn(batch_size, prefill_len, num_heads, head_dim,
                            dtype=torch.float16, device='cuda')
    k_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
                            dtype=torch.float16, device='cuda')
    v_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
                            dtype=torch.float16, device='cuda')

    k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
                          dtype=torch.float16, device='cuda')
    v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
                          dtype=torch.float16, device='cuda')
    cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')

    # Run chunked prefill
    prefill_out = chunk_prefill(q_prefill, k_prefill, v_prefill,
                                k_cache, v_cache, cache_seqlens, chunk_size)

    print(f" After prefill: cache_seqlens={cache_seqlens[0].item()}")

    # Decode phase - one token at a time
    for step in range(num_decode_steps):
        q_decode = torch.randn(batch_size, 1, num_heads, head_dim,
                               dtype=torch.float16, device='cuda')
        k_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
                               dtype=torch.float16, device='cuda')
        v_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
                               dtype=torch.float16, device='cuda')

        decode_out = flash_attn_with_kvcache(
            q_decode,
            k_cache,
            v_cache,
            k=k_decode,
            v=v_decode,
            cache_seqlens=cache_seqlens,
            causal=True,
        )

        cache_seqlens += 1

        assert decode_out.shape == (batch_size, 1, num_heads, head_dim)

    expected_len = prefill_len + num_decode_steps
    actual_len = cache_seqlens[0].item()

    print(f" After decode: cache_seqlens={actual_len}")

    if actual_len == expected_len:
        print(f" PASSED")
        return True
    else:
        print(f" FAILED: expected {expected_len}, got {actual_len}")
        return False


def test_batch_processing():
    """Test chunked prefill with batch > 1."""

    batch_size = 4
    num_heads = 32
    num_kv_heads = 8
    head_dim = 128
    max_seq_len = 4096
    seq_len = 2048
    chunk_size = 512

    print(f"\nTesting batch processing (batch_size={batch_size})...")

    torch.manual_seed(42)

    q = torch.randn(batch_size, seq_len, num_heads, head_dim,
                    dtype=torch.float16, device='cuda')
    k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
                    dtype=torch.float16, device='cuda')
    v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
                    dtype=torch.float16, device='cuda')

    k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
                          dtype=torch.float16, device='cuda')
    v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
                          dtype=torch.float16, device='cuda')
    cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')

    out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)

    # Verify all batches have correct cache length
    assert (cache_seqlens == seq_len).all(), f"Cache seqlens mismatch: {cache_seqlens}"
    assert out.shape == (batch_size, seq_len, num_heads, head_dim)

    # Compare with reference for each batch item
    k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
    v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)
    ref_out = reference_attention(q, k_expanded, v_expanded)

    max_diff = (ref_out - out).abs().max().item()

    print(f" Output shape: {out.shape}")
    print(f" Max diff vs reference: {max_diff:.6f}")

    if max_diff < 1e-2:
        print(f" PASSED")
        return True
    else:
        print(f" FAILED")
        return False


# ============================================================
# Main Test Script
# ============================================================

if __name__ == "__main__":
    print("=" * 60)
    print("Testing flash_attn_with_kvcache chunked prefill")
    print("=" * 60)

    all_passed = True

    all_passed &= test_chunked_prefill_correctness()
    all_passed &= test_incremental_decode()
    all_passed &= test_batch_processing()

    print("\n" + "=" * 60)
    if all_passed:
        print("test_flash_attn_kvcache: ALL TESTS PASSED")
    else:
        print("test_flash_attn_kvcache: SOME TESTS FAILED")
    print("=" * 60)

@@ -1,104 +0,0 @@
"""
Test FlashInfer chunked attention with CPU offload.

Uses single_prefill_with_kv_cache + merge_state for chunked KV processing.
"""

import torch
import flashinfer


# ============================================================
# Core Functions
# ============================================================

def chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk_size, kv_chunk_size):
    """
    Chunked causal attention with KV on CPU.

    q: [seq_q, num_heads, head_dim] on GPU
    k_cpu, v_cpu: [seq_kv, num_kv_heads, head_dim] on CPU
    """
    seq_q = q.shape[0]
    seq_kv = k_cpu.shape[0]
    final_outputs = []

    for q_start in range(0, seq_q, q_chunk_size):
        q_end = min(q_start + q_chunk_size, seq_q)
        q_chunk = q[q_start:q_end]

        merged_output = None
        merged_lse = None

        for kv_start in range(0, seq_kv, kv_chunk_size):
            kv_end = min(kv_start + kv_chunk_size, seq_kv)

            if kv_start >= q_end:
                continue

            k_chunk = k_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
            v_chunk = v_cpu[kv_start:kv_end].to(q.device, non_blocking=True)

            causal = not (kv_end <= q_start)
            partial_out, partial_lse = flashinfer.single_prefill_with_kv_cache(
                q_chunk, k_chunk, v_chunk,
                causal=causal,
                return_lse=True,
            )

            if merged_output is None:
                merged_output, merged_lse = partial_out, partial_lse
            else:
                merged_output, merged_lse = flashinfer.merge_state(
                    merged_output, merged_lse,
                    partial_out, partial_lse,
                )

        final_outputs.append(merged_output)

    return torch.cat(final_outputs, dim=0)


# ============================================================
# Main Test Script
# ============================================================

print("=" * 60)
print("Testing FlashInfer chunked attention with CPU offload")
print("=" * 60)

num_heads = 32
num_kv_heads = 8
head_dim = 128

test_configs = [
    (32768, 8192, 8192),     # 32K tokens
    (65536, 8192, 8192),     # 64K tokens
    (131072, 16384, 16384),  # 128K tokens
    # (262144, 16384, 16384),  # 256K tokens (slow)
    # (524288, 16384, 16384),  # 512K tokens (slow)
]

for seq_len, q_chunk, kv_chunk in test_configs:
    torch.manual_seed(42)

    q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device='cuda')
    k_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
    v_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')

    # Chunked result
    chunked_out = chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk, kv_chunk)

    # Reference
    k_gpu = k_cpu.to('cuda')
    v_gpu = v_cpu.to('cuda')
    ref_out = flashinfer.single_prefill_with_kv_cache(q, k_gpu, v_gpu, causal=True)

    max_diff = (ref_out - chunked_out).abs().max().item()
    mean_diff = (ref_out - chunked_out).abs().mean().item()

    num_chunks = (seq_len + q_chunk - 1) // q_chunk
    assert max_diff < 1e-2, f"FAILED: max_diff={max_diff:.6f}"
    print(f"seq={seq_len//1024}K, chunks={num_chunks}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")

print("\ntest_flashinfer_merge: PASSED")

@@ -8,155 +8,12 @@ sequences longer than ~200 tokens. Use --no-offload for correctness testing.
"""

import os
-os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
+os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"

import argparse
from nanovllm import LLM, SamplingParams
+from nanovllm.config import SparsePolicyType
+from utils import generate_needle_prompt, check_needle_answer
-# ============================================================
-# Needle Test Generator
-# ============================================================
-
-def generate_needle_prompt(
-    tokenizer,
-    target_length: int,
-    needle_position: float = 0.5,
-    needle_value: str = "7492",
-    use_chat_template: bool = True,
-) -> tuple[str, str]:
-    """
-    Generate a needle-in-haystack prompt of approximately target_length tokens.
-
-    Args:
-        tokenizer: HuggingFace tokenizer for length estimation
-        target_length: Target total sequence length in tokens
-        needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
-        needle_value: The secret value to hide in the haystack
-        use_chat_template: Whether to use chat template for instruct models
-
-    Returns:
-        (prompt, expected_answer): The full prompt and the expected needle value
-    """
-    # Haystack filler paragraphs (various topics to create realistic context)
-    haystack_paragraphs = [
-        "The weather today is quite pleasant with clear skies and moderate temperatures. "
-        "Many people are enjoying outdoor activities in the park. "
-        "Birds are singing in the trees and children are playing on the swings. ",
-
-        "In the world of technology, new innovations continue to emerge every day. "
-        "Researchers are working on advanced algorithms and computing systems. "
-        "The future of artificial intelligence looks promising with many breakthroughs. ",
-
-        "The history of human civilization spans thousands of years. "
-        "Ancient cultures developed writing, mathematics, and astronomy. "
-        "Trade routes connected distant lands and facilitated cultural exchange. ",
-
-        "Modern cooking combines traditional techniques with new ingredients. "
-        "Chefs around the world experiment with flavors and presentations. "
-        "Food brings people together and creates memorable experiences. ",
-
-        "The ocean covers more than seventy percent of Earth's surface. "
-        "Marine ecosystems support an incredible diversity of life forms. "
-        "Scientists continue to discover new species in the deep sea. ",
-
-        "Music has been a part of human culture since prehistoric times. "
-        "Different genres evolved across various regions and time periods. "
-        "Today, people can access millions of songs through digital platforms. ",
-
-        "Space exploration has revealed many secrets about our universe. "
-        "Telescopes can observe galaxies billions of light years away. "
-        "Future missions aim to establish human presence on other planets. ",
-
-        "The study of languages reveals patterns in human cognition. "
-        "Linguists analyze grammar, semantics, and phonetics across cultures. "
-        "Language continues to evolve with new words and expressions. ",
-    ]
-
-    # The needle sentence
-    needle = f"The secret number you need to remember is {needle_value}. This is very important. "
-
-    # Question at the end
-    question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
-
-    # Estimate tokens for fixed parts
-    needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
-    question_text = "What is the secret number mentioned in the text above? Answer with just the number."
-    question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
-    # Buffer for chat template, special tokens, etc.
-    overhead_tokens = 100 if use_chat_template else 50
-
-    # Available tokens for haystack
-    haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
-    if haystack_target_tokens < 100:
-        raise ValueError(f"target_length {target_length} is too short for needle test")
-
-    # Build haystack by repeating paragraphs
-    haystack_parts = []
-    current_tokens = 0
-    para_idx = 0
-
-    while current_tokens < haystack_target_tokens:
-        para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
-        para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
-        if current_tokens + para_tokens > haystack_target_tokens:
-            break
-        haystack_parts.append(para)
-        current_tokens += para_tokens
-        para_idx += 1
-
-    # Calculate needle insertion point
-    needle_idx = int(len(haystack_parts) * needle_position)
-    needle_idx = max(0, min(needle_idx, len(haystack_parts)))
-
-    # Insert needle
-    haystack_parts.insert(needle_idx, needle)
-
-    # Assemble prompt
-    full_text = "".join(haystack_parts)
-
-    if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
-        # Use chat template for instruct models
-        # For Qwen3, add /no_think to disable thinking mode
-        question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
-        messages = [
-            {"role": "user", "content": f"{full_text}\n\n{question_text}"}
-        ]
-        prompt = tokenizer.apply_chat_template(
-            messages,
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-    else:
-        # Raw text format for base models
-        question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
-        prompt = full_text + question
-
-    # Verify length
-    actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
-    print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
-    print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
-    print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
-
-    return prompt, needle_value
-
-
-def check_needle_answer(output_text: str, expected: str) -> bool:
-    """Check if the model output contains the expected needle value."""
-    import re
-    # Clean output - remove special tokens and whitespace
-    output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
-    output_clean = ' '.join(output_clean.split()).lower()
-    expected_clean = expected.strip().lower()
-
-    # Check if expected value appears in output
-    # Also try to find it as a standalone number
-    if expected_clean in output_clean:
-        return True
-
-    # Try to extract numbers and check if expected is among them
-    numbers = re.findall(r'\d+', output_clean)
-    return expected_clean in numbers
-
-
# ============================================================
@@ -168,10 +25,14 @@ def run_needle_test(
    max_model_len: int,
    input_len: int,
    num_gpu_blocks: int = 4,
+    block_size: int = 1024,
    needle_position: float = 0.5,
    needle_value: str = "7492",
    max_new_tokens: int = 32,
    enable_cpu_offload: bool = False,
+    enable_quest: bool = False,
+    sparse_topk: int = 8,
+    sparse_threshold: int = 4,
    verbose: bool = True,
) -> bool:
    """
@@ -182,15 +43,21 @@ def run_needle_test(
        max_model_len: Maximum model context length
        input_len: Target input sequence length
        num_gpu_blocks: Number of GPU blocks for offload
+        block_size: KV cache block size
        needle_position: Where to place needle (0.0-1.0)
        needle_value: The secret value to find
        max_new_tokens: Maximum tokens to generate
        enable_cpu_offload: Enable CPU offload mode
+        enable_quest: Enable Quest sparse attention (decode-only Top-K)
+        sparse_topk: Top-K blocks for Quest
+        sparse_threshold: Apply sparse only when blocks > threshold
        verbose: Print detailed output

    Returns:
        True if test passed, False otherwise
    """
+    sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL
+
    if verbose:
        print(f"\n{'='*60}")
        print(f"Needle-in-Haystack Test")
@@ -198,9 +65,12 @@ def run_needle_test(
        print(f"Model: {model_path}")
        print(f"Max model len: {max_model_len}")
        print(f"Input length: {input_len}")
+        print(f"Block size: {block_size}")
        print(f"Needle position: {needle_position:.0%}")
        print(f"Needle value: {needle_value}")
        print(f"CPU offload: {enable_cpu_offload}")
+        if enable_cpu_offload:
+            print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})")
        print(f"{'='*60}\n")

    # 1. Initialize LLM
@@ -209,9 +79,13 @@ def run_needle_test(
        "max_model_len": max_model_len,
        "max_num_batched_tokens": max_model_len,
        "enable_cpu_offload": enable_cpu_offload,
+        "kvcache_block_size": block_size,
    }
    if enable_cpu_offload:
        llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
+        llm_kwargs["sparse_policy"] = sparse_policy
+        llm_kwargs["sparse_topk_blocks"] = sparse_topk
+        llm_kwargs["sparse_threshold_blocks"] = sparse_threshold

    llm = LLM(model_path, **llm_kwargs)

@@ -263,7 +137,7 @@ if __name__ == "__main__":
    parser.add_argument(
        "--max-model-len",
        type=int,
-        default=32 * 1024,
+        default=128 * 1024,
        help="Maximum model context length"
    )
    parser.add_argument(
@@ -278,6 +152,12 @@ if __name__ == "__main__":
        default=2,
        help="Number of GPU blocks for CPU offload"
    )
+    parser.add_argument(
+        "--block-size",
+        type=int,
+        default=1024,
+        help="KV cache block size"
+    )
    parser.add_argument(
        "--needle-position",
        type=float,
@@ -301,6 +181,23 @@ if __name__ == "__main__":
        action="store_true",
        help="Enable CPU offload (has known bug for long sequences)"
    )
+    parser.add_argument(
+        "--enable-quest",
+        action="store_true",
+        help="Enable Quest sparse attention (decode-only Top-K selection)"
+    )
+    parser.add_argument(
+        "--sparse-topk",
+        type=int,
+        default=8,
+        help="Top-K blocks for Quest sparse attention"
+    )
+    parser.add_argument(
+        "--sparse-threshold",
+        type=int,
+        default=4,
+        help="Apply sparse only when blocks > threshold"
+    )
    args = parser.parse_args()

    passed = run_needle_test(
@@ -308,10 +205,14 @@ if __name__ == "__main__":
        max_model_len=args.max_model_len,
        input_len=args.input_len,
        num_gpu_blocks=args.num_gpu_blocks,
+        block_size=args.block_size,
        needle_position=args.needle_position,
        needle_value=args.needle_value,
        max_new_tokens=args.max_new_tokens,
        enable_cpu_offload=args.enable_offload,
+        enable_quest=args.enable_quest,
+        sparse_topk=args.sparse_topk,
+        sparse_threshold=args.sparse_threshold,
        verbose=True,
    )
@@ -8,148 +8,9 @@ Uses standard HuggingFace inference (no custom KV cache, no offload).
|
|||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
from modeling_qwen3 import Qwen3ForCausalLM
|
||||||
|
from utils import generate_needle_prompt, check_needle_answer
|
||||||
# ============================================================
|
|
||||||
# Needle Test Generator
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def generate_needle_prompt(
|
|
||||||
tokenizer,
|
|
||||||
target_length: int,
|
|
||||||
needle_position: float = 0.5,
|
|
||||||
needle_value: str = "7492",
|
|
||||||
use_chat_template: bool = True,
|
|
||||||
) -> tuple[str, str]:
|
|
||||||
"""
|
|
||||||
Generate a needle-in-haystack prompt of approximately target_length tokens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tokenizer: HuggingFace tokenizer for length estimation
|
|
||||||
target_length: Target total sequence length in tokens
|
|
||||||
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
|
|
||||||
needle_value: The secret value to hide in the haystack
|
|
||||||
use_chat_template: Whether to use chat template for instruct models
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(prompt, expected_answer): The full prompt and the expected needle value
|
|
||||||
"""
|
|
||||||
# Haystack filler paragraphs (various topics to create realistic context)
|
|
||||||
haystack_paragraphs = [
|
|
||||||
"The weather today is quite pleasant with clear skies and moderate temperatures. "
|
|
||||||
"Many people are enjoying outdoor activities in the park. "
|
|
||||||
"Birds are singing in the trees and children are playing on the swings. ",
|
|
||||||
|
|
||||||
"In the world of technology, new innovations continue to emerge every day. "
|
|
||||||
"Researchers are working on advanced algorithms and computing systems. "
|
|
||||||
"The future of artificial intelligence looks promising with many breakthroughs. ",
|
|
||||||
|
|
||||||
"The history of human civilization spans thousands of years. "
|
|
||||||
"Ancient cultures developed writing, mathematics, and astronomy. "
|
|
||||||
"Trade routes connected distant lands and facilitated cultural exchange. ",
|
|
||||||
|
|
||||||
"Modern cooking combines traditional techniques with new ingredients. "
|
|
||||||
"Chefs around the world experiment with flavors and presentations. "
|
|
||||||
"Food brings people together and creates memorable experiences. ",
|
|
||||||
|
|
||||||
"The ocean covers more than seventy percent of Earth's surface. "
|
|
||||||
"Marine ecosystems support an incredible diversity of life forms. "
|
|
||||||
"Scientists continue to discover new species in the deep sea. ",
|
|
||||||
|
|
||||||
"Music has been a part of human culture since prehistoric times. "
|
|
||||||
"Different genres evolved across various regions and time periods. "
|
|
||||||
"Today, people can access millions of songs through digital platforms. ",
|
|
||||||
|
|
||||||
"Space exploration has revealed many secrets about our universe. "
|
|
||||||
"Telescopes can observe galaxies billions of light years away. "
|
|
||||||
"Future missions aim to establish human presence on other planets. ",
|
|
||||||
|
|
||||||
"The study of languages reveals patterns in human cognition. "
|
|
||||||
"Linguists analyze grammar, semantics, and phonetics across cultures. "
|
|
||||||
"Language continues to evolve with new words and expressions. ",
|
|
||||||
]
|
|
||||||
|
|
||||||
# The needle sentence
|
|
||||||
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
|
|
||||||
|
|
||||||
# Estimate tokens for fixed parts
|
|
||||||
needle_tokens = len(tokenizer.encode(needle, add_special_tokens=False))
|
|
||||||
question_text = "What is the secret number mentioned in the text above? Answer with just the number."
|
|
||||||
question_tokens = len(tokenizer.encode(question_text, add_special_tokens=False))
|
|
||||||
# Buffer for chat template, special tokens, etc.
|
|
||||||
overhead_tokens = 100 if use_chat_template else 50
|
|
||||||
|
|
||||||
# Available tokens for haystack
|
|
||||||
haystack_target_tokens = target_length - needle_tokens - question_tokens - overhead_tokens
|
|
||||||
if haystack_target_tokens < 100:
|
|
||||||
raise ValueError(f"target_length {target_length} is too short for needle test")
|
|
||||||
|
|
||||||
# Build haystack by repeating paragraphs
|
|
||||||
haystack_parts = []
|
|
||||||
current_tokens = 0
|
|
||||||
para_idx = 0
|
|
||||||
|
|
||||||
while current_tokens < haystack_target_tokens:
|
|
||||||
para = haystack_paragraphs[para_idx % len(haystack_paragraphs)]
|
|
||||||
para_tokens = len(tokenizer.encode(para, add_special_tokens=False))
|
|
||||||
if current_tokens + para_tokens > haystack_target_tokens:
|
|
||||||
break
|
|
||||||
haystack_parts.append(para)
|
|
||||||
current_tokens += para_tokens
|
|
||||||
para_idx += 1
|
|
||||||
|
|
||||||
# Calculate needle insertion point
|
|
||||||
needle_idx = int(len(haystack_parts) * needle_position)
|
|
||||||
needle_idx = max(0, min(needle_idx, len(haystack_parts)))
|
|
||||||
|
|
||||||
# Insert needle
|
|
||||||
haystack_parts.insert(needle_idx, needle)
|
|
||||||
|
|
||||||
# Assemble prompt
|
|
||||||
full_text = "".join(haystack_parts)
|
|
||||||
|
|
||||||
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
|
|
||||||
# Use chat template for instruct models
|
|
||||||
# For Qwen3, add /no_think to disable thinking mode
|
|
||||||
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
|
|
||||||
messages = [
|
|
||||||
{"role": "user", "content": f"{full_text}\n\n{question_text}"}
|
|
||||||
]
|
|
||||||
prompt = tokenizer.apply_chat_template(
|
|
||||||
messages,
|
|
||||||
tokenize=False,
|
|
||||||
add_generation_prompt=True,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Raw text format for base models
|
|
||||||
question = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
|
||||||
prompt = full_text + question
|
|
||||||
|
|
||||||
# Verify length
|
|
||||||
actual_tokens = len(tokenizer.encode(prompt, add_special_tokens=False))
|
|
||||||
print(f"[NeedleTest] Target: {target_length} tokens, Actual: {actual_tokens} tokens")
|
|
||||||
print(f"[NeedleTest] Needle position: {needle_position:.0%} ({needle_idx}/{len(haystack_parts)-1} paragraphs)")
|
|
||||||
print(f"[NeedleTest] Using chat template: {use_chat_template and hasattr(tokenizer, 'apply_chat_template')}")
|
|
||||||
|
|
||||||
return prompt, needle_value
|
|
||||||
|
|
||||||
|
|
||||||
def check_needle_answer(output_text: str, expected: str) -> bool:
|
|
||||||
"""Check if the model output contains the expected needle value."""
|
|
||||||
import re
|
|
||||||
# Clean output - remove special tokens and whitespace
|
|
||||||
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
|
|
||||||
output_clean = ' '.join(output_clean.split()).lower()
|
|
||||||
expected_clean = expected.strip().lower()
|
|
||||||
|
|
||||||
# Check if expected value appears in output
|
|
||||||
if expected_clean in output_clean:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Try to extract numbers and check if expected is among them
|
|
||||||
numbers = re.findall(r'\d+', output_clean)
|
|
||||||
return expected_clean in numbers
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
@@ -207,22 +68,19 @@ def run_needle_test(
|
|||||||
# 3. Load model
|
# 3. Load model
|
||||||
print("[3/4] Loading model...")
|
print("[3/4] Loading model...")
|
||||||
torch_dtype = {
|
torch_dtype = {
|
||||||
"auto": "auto",
|
"auto": torch.float16, # default to float16 for custom model
|
||||||
"float16": torch.float16,
|
"float16": torch.float16,
|
||||||
"bfloat16": torch.bfloat16,
|
"bfloat16": torch.bfloat16,
|
||||||
}.get(dtype, "auto")
|
}.get(dtype, torch.float16)
|
||||||
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = Qwen3ForCausalLM.from_pretrained(model_path, dtype=torch_dtype)
|
||||||
model_path,
|
model = model.to("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
device_map="auto",
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
# 4. Generate output
|
# 4. Generate output
|
||||||
print("[4/4] Running inference...")
|
print("[4/4] Running inference...")
|
||||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
|
device = next(model.parameters()).device
|
||||||
|
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
||||||
print(f" Input shape: {input_ids.shape}")
|
print(f" Input shape: {input_ids.shape}")
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
@@ -1,695 +0,0 @@
|
|||||||
"""
|
|
||||||
Test script to verify CPU offload correctness using distinctive KV patterns.
|
|
||||||
|
|
||||||
Strategy:
|
|
||||||
1. Hook into attention forward pass
|
|
||||||
2. Overwrite K/V with distinctive patterns based on chunk_idx (e.g., K=chunk_idx, V=-chunk_idx)
|
|
||||||
3. After offload to CPU, verify CPU cache contains correct patterns
|
|
||||||
4. On subsequent chunks, verify loaded KV from CPU has correct patterns
|
|
||||||
|
|
||||||
This catches bugs like:
|
|
||||||
- Wrong block being offloaded
|
|
||||||
- Wrong block being loaded
|
|
||||||
- Data corruption during transfer
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from random import randint, seed
|
|
||||||
from nanovllm import LLM, SamplingParams
|
|
||||||
from nanovllm.utils.context import get_context
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Configuration
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
|
|
||||||
MAX_MODEL_LEN = 64 * 1024
|
|
||||||
NUM_GPU_BLOCKS = 4
|
|
||||||
INPUT_LEN = 32 * 1024 # 32K tokens = 32 chunks (fits in 40 CPU blocks)
|
|
||||||
BLOCK_SIZE = 1024
|
|
||||||
|
|
||||||
# Test state
|
|
||||||
errors = []
|
|
||||||
chunk_patterns = {} # chunk_idx -> (k_pattern, v_pattern)
|
|
||||||
block_coverage = {} # chunk_idx -> set of blocks that were actually computed
|
|
||||||
load_operations = [] # List of (chunk_idx, slot_id, cpu_block_id, k_ok, v_ok) tuples
|
|
||||||
current_chunk_for_load = [0] # Mutable container to track current chunk during loads
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Pattern Helpers
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def get_expected_pattern(chunk_idx: int):
|
|
||||||
"""Get expected K/V pattern for a chunk."""
|
|
||||||
# Use float values that are easy to identify
|
|
||||||
k_val = float(chunk_idx + 1) # 1.0, 2.0, 3.0, ...
|
|
||||||
v_val = float(-(chunk_idx + 1)) # -1.0, -2.0, -3.0, ...
|
|
||||||
return k_val, v_val
|
|
||||||
|
|
||||||
|
|
||||||
def fill_with_pattern(tensor: torch.Tensor, value: float):
|
|
||||||
"""Fill tensor with a constant value."""
|
|
||||||
tensor.fill_(value)
|
|
||||||
|
|
||||||
|
|
||||||
def check_pattern(tensor: torch.Tensor, expected: float, name: str, tolerance: float = 1e-3):
|
|
||||||
"""Check if tensor contains expected pattern."""
|
|
||||||
actual_mean = tensor.float().mean().item()
|
|
||||||
if abs(actual_mean - expected) > tolerance:
|
|
||||||
return False, f"{name}: expected mean={expected}, got {actual_mean}"
|
|
||||||
return True, None
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Load Verification Instrumentation
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
_original_load_to_slot_layer = None
|
|
||||||
_offload_engine_ref = None
|
|
||||||
|
|
||||||
def make_verified_load_to_slot_layer(original_func, offload_engine):
|
|
||||||
"""
|
|
||||||
Create a wrapper around load_to_slot_layer that verifies each load operation.
|
|
||||||
|
|
||||||
After each H2D transfer, checks that the GPU slot contains the expected
|
|
||||||
pattern from the source CPU block.
|
|
||||||
"""
|
|
||||||
def verified_load(slot_idx: int, layer_id: int, cpu_block_id: int):
|
|
||||||
# Call original load
|
|
||||||
original_func(slot_idx, layer_id, cpu_block_id)
|
|
||||||
|
|
||||||
# Only verify layer 0 to reduce overhead
|
|
||||||
if layer_id != 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# IMPORTANT: Synchronize CUDA to ensure async transfer is complete
|
|
||||||
# The transfer happens on a per-slot stream, and wait_slot_layer only
|
|
||||||
# makes compute_stream wait. We need full sync to read on default stream.
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Get the expected pattern for this CPU block
|
|
||||||
# cpu_block_id == chunk_idx in our sequential test
|
|
||||||
expected_k, expected_v = get_expected_pattern(cpu_block_id)
|
|
||||||
|
|
||||||
# Read GPU slot data (GPU cache has no layer dimension)
|
|
||||||
gpu_k = offload_engine.k_cache_gpu[slot_idx]
|
|
||||||
gpu_v = offload_engine.v_cache_gpu[slot_idx]
|
|
||||||
|
|
||||||
actual_k = gpu_k.float().mean().item()
|
|
||||||
actual_v = gpu_v.float().mean().item()
|
|
||||||
|
|
||||||
k_ok = abs(actual_k - expected_k) < 1e-3
|
|
||||||
v_ok = abs(actual_v - expected_v) < 1e-3
|
|
||||||
|
|
||||||
chunk_idx = current_chunk_for_load[0]
|
|
||||||
load_operations.append({
|
|
||||||
'chunk_idx': chunk_idx,
|
|
||||||
'slot_idx': slot_idx,
|
|
||||||
'cpu_block_id': cpu_block_id,
|
|
||||||
'expected_k': expected_k,
|
|
||||||
'expected_v': expected_v,
|
|
||||||
'actual_k': actual_k,
|
|
||||||
'actual_v': actual_v,
|
|
||||||
'k_ok': k_ok,
|
|
||||||
'v_ok': v_ok,
|
|
||||||
})
|
|
||||||
|
|
||||||
if not (k_ok and v_ok):
|
|
||||||
errors.append(f"Load verification failed: chunk {chunk_idx}, "
|
|
||||||
f"CPU block {cpu_block_id} -> GPU slot {slot_idx}: "
|
|
||||||
f"expected K={expected_k:.1f}/V={expected_v:.1f}, "
|
|
||||||
f"got K={actual_k:.4f}/V={actual_v:.4f}")
|
|
||||||
|
|
||||||
return verified_load
|
|
||||||
|
|
||||||
|
|
||||||
def install_load_verification(llm):
|
|
||||||
"""Install verification wrapper on load_to_slot_layer."""
|
|
||||||
global _original_load_to_slot_layer, _offload_engine_ref
|
|
||||||
|
|
||||||
oe = llm.model_runner.kvcache_manager.offload_engine
|
|
||||||
_offload_engine_ref = oe
|
|
||||||
_original_load_to_slot_layer = oe.load_to_slot_layer
|
|
||||||
|
|
||||||
oe.load_to_slot_layer = make_verified_load_to_slot_layer(
|
|
||||||
_original_load_to_slot_layer, oe
|
|
||||||
)
|
|
||||||
print("Installed load verification wrapper on load_to_slot_layer")
|
|
||||||
|
|
||||||
|
|
||||||
def uninstall_load_verification():
|
|
||||||
"""Restore original load_to_slot_layer."""
|
|
||||||
global _original_load_to_slot_layer, _offload_engine_ref
|
|
||||||
|
|
||||||
if _offload_engine_ref is not None and _original_load_to_slot_layer is not None:
|
|
||||||
_offload_engine_ref.load_to_slot_layer = _original_load_to_slot_layer
|
|
||||||
print("Restored original load_to_slot_layer")
|
|
||||||
|
|
||||||
_original_load_to_slot_layer = None
|
|
||||||
_offload_engine_ref = None
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
# Attention Hooks
# ============================================================

def make_kv_pattern_pre_hook(layer_id: int):
    """
    Create a PRE-forward hook that overwrites K/V with distinctive patterns BEFORE
    they are stored to cache. This is called before attention.forward().

    register_forward_pre_hook receives (module, inputs) and can modify inputs in-place.
    """
    def hook(module, inputs):
        ctx = get_context()
        if not ctx.is_prefill:
            return

        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None

        if kvcache_manager is None:
            return

        # Only process layer 0 for cleaner output
        if layer_id != 0:
            return

        q, k, v = inputs
        k_pattern, v_pattern = get_expected_pattern(chunk_idx)

        # === Overwrite current chunk's K/V with distinctive pattern ===
        # This happens BEFORE forward(), so these values will be stored to cache
        k.fill_(k_pattern)
        v.fill_(v_pattern)

        # Only print for first few and last few chunks to reduce noise
        num_chunks = INPUT_LEN // BLOCK_SIZE
        if chunk_idx < 3 or chunk_idx >= num_chunks - 2:
            print(f"[Chunk {chunk_idx:3d}] Set K={k_pattern:.1f}, V={v_pattern:.1f}")
        elif chunk_idx == 3:
            print(f"... (chunks 3 to {num_chunks - 3} omitted) ...")

    return hook


def make_block_coverage_pre_hook(layer_id: int):
    """
    Create a PRE-forward hook to verify that all previous blocks are included
    in the cpu_block_table for chunked prefill attention.

    This catches bugs where:
    - Some blocks are missing from the computation
    - Sparse policy incorrectly filters out blocks (when not intended)
    - Block table construction has off-by-one errors
    """
    def hook(module, inputs):
        ctx = get_context()
        if not ctx.is_prefill:
            return

        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None

        if kvcache_manager is None:
            return

        # Only process layer 0 for cleaner output
        if layer_id != 0:
            return

        # Update current chunk for load verification tracking
        current_chunk_for_load[0] = chunk_idx

        # No previous blocks for chunk 0
        if chunk_idx == 0:
            return

        # Get the sequence and its block table (same logic as _chunked_prefill_attention)
        seq = ctx.chunked_seq if hasattr(ctx, 'chunked_seq') else None
        if seq is None:
            return

        # Get the CPU block table that will be used for attention
        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

        # Expected blocks: 0 to chunk_idx-1 (all previous chunks)
        expected_blocks = set(range(chunk_idx))
        actual_blocks = set(cpu_block_table) if cpu_block_table else set()

        # Store for later summary
        block_coverage[chunk_idx] = {
            'expected': expected_blocks,
            'actual': actual_blocks,
        }

        # Check for missing blocks
        missing_blocks = expected_blocks - actual_blocks
        extra_blocks = actual_blocks - expected_blocks

        num_chunks = INPUT_LEN // BLOCK_SIZE
        if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or missing_blocks:
            if not missing_blocks and not extra_blocks:
                print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [OK]")
            else:
                status_parts = []
                if missing_blocks:
                    status_parts.append(f"MISSING {sorted(missing_blocks)}")
                if extra_blocks:
                    status_parts.append(f"EXTRA {sorted(extra_blocks)}")
                print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [{', '.join(status_parts)}]")
        elif chunk_idx == 4:
            # Indicate that middle chunks are being verified silently
            print(f" ... (verifying chunks 4-{num_chunks - 3} silently) ...")

        if missing_blocks:
            errors.append(f"Chunk {chunk_idx} missing blocks: {sorted(missing_blocks)}")

    return hook


def make_gpu_write_verification_post_hook(layer_id: int):
    """
    Create a POST-forward hook to verify the current chunk's KV was correctly
    written to the GPU ring buffer write_slot.

    This is a more reliable verification than checking load slots, because:
    1. Post-hook runs AFTER forward() writes to GPU cache
    2. write_slot mapping is deterministic: chunk_idx % num_ring_slots
    3. We injected known patterns in pre-hook, now verify they're in GPU cache
    """
    def hook(module, inputs, output):
        ctx = get_context()
        if not ctx.is_prefill:
            return

        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None

        if kvcache_manager is None:
            return

        # Only process layer 0 for cleaner output
        if layer_id != 0:
            return

        oe = kvcache_manager.offload_engine
        num_ring_slots = oe.num_ring_slots
        write_slot = chunk_idx % num_ring_slots

        # Get expected pattern for current chunk
        expected_k, expected_v = get_expected_pattern(chunk_idx)

        # Verify write_slot contains current chunk's data (GPU cache has no layer dimension)
        gpu_k = oe.k_cache_gpu[write_slot]
        gpu_v = oe.v_cache_gpu[write_slot]

        actual_k_mean = gpu_k.float().mean().item()
        actual_v_mean = gpu_v.float().mean().item()

        k_ok, _ = check_pattern(gpu_k, expected_k, f"GPU slot {write_slot}")
        v_ok, _ = check_pattern(gpu_v, expected_v, f"GPU slot {write_slot}")

        num_chunks = INPUT_LEN // BLOCK_SIZE
        # Print for first/last chunks, or if there's an error
        if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or not (k_ok and v_ok):
            if k_ok and v_ok:
                print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: K={expected_k:.1f}, V={expected_v:.1f} [OK]")
            else:
                print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: expected K={expected_k:.1f}/V={expected_v:.1f}, "
                      f"got K={actual_k_mean:.2f}/V={actual_v_mean:.2f} [FAIL]")
        elif chunk_idx == 4:
            print(f" ... (GPU write verification for chunks 4-{num_chunks - 3} silently) ...")

        if not (k_ok and v_ok):
            errors.append(f"GPU write_slot {write_slot} at chunk {chunk_idx}: "
                          f"expected K={expected_k}, V={expected_v}, got K={actual_k_mean:.4f}, V={actual_v_mean:.4f}")

    return hook


def make_kv_verification_post_hook(layer_id: int):
    """
    Create a POST-forward hook to verify CPU cache contains correct patterns
    from previously offloaded blocks.
    """
    def hook(module, inputs, output):
        ctx = get_context()
        if not ctx.is_prefill:
            return

        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None

        if kvcache_manager is None:
            return

        # Only process layer 0 for cleaner output
        if layer_id != 0:
            return

        # === Verify previously offloaded blocks in CPU cache ===
        if chunk_idx >= 1:
            oe = kvcache_manager.offload_engine
            num_ok = 0
            num_fail = 0

            # Check all previously offloaded blocks
            for prev_chunk in range(chunk_idx):
                # CPU block ID = prev_chunk (in simple sequential case)
                cpu_block_id = prev_chunk

                # Get expected pattern for this block
                expected_k, expected_v = get_expected_pattern(prev_chunk)

                # Read from CPU cache (layer 0)
                cpu_k = oe.k_cache_cpu[layer_id, cpu_block_id]
                cpu_v = oe.v_cache_cpu[layer_id, cpu_block_id]

                # Verify patterns
                k_ok, k_err = check_pattern(cpu_k, expected_k, f"CPU K block {cpu_block_id}")
                v_ok, v_err = check_pattern(cpu_v, expected_v, f"CPU V block {cpu_block_id}")

                if k_ok and v_ok:
                    num_ok += 1
                else:
                    num_fail += 1
                    if k_err:
                        errors.append(k_err)
                    if v_err:
                        errors.append(v_err)

            # Only print summary for each chunk verification
            num_chunks = INPUT_LEN // BLOCK_SIZE
            if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or num_fail > 0:
                status = "OK" if num_fail == 0 else f"FAIL({num_fail})"
                print(f" CPU verify chunk {chunk_idx:2d}: {num_ok} blocks OK [{status}]")
            elif chunk_idx == 4:
                print(f" ... (CPU cache verification for chunks 4-{num_chunks - 3} silently) ...")

    return hook


def make_post_chunk_verification_hook(layer_id: int):
    """
    Post-forward hook to verify GPU ring buffer state after attention.
    """
    def hook(module, inputs, output):
        ctx = get_context()
        if not ctx.is_prefill or layer_id != 0:
            return

        chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
        kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None

        if kvcache_manager is None:
            return

        oe = kvcache_manager.offload_engine

        # After attention, the current chunk's KV should be in the GPU ring buffer
        # Ring slot = chunk_idx % num_ring_slots
        ring_slot = chunk_idx % oe.num_ring_slots

        expected_k, expected_v = get_expected_pattern(chunk_idx)

        # Check GPU ring buffer (GPU cache has no layer dimension)
        gpu_k = oe.k_cache_gpu[ring_slot]
        gpu_v = oe.v_cache_gpu[ring_slot]

        k_ok, k_err = check_pattern(gpu_k, expected_k, f"GPU K slot {ring_slot}")
        v_ok, v_err = check_pattern(gpu_v, expected_v, f"GPU V slot {ring_slot}")

        if k_ok and v_ok:
            print(f" [OK] GPU slot {ring_slot} (chunk {chunk_idx}): K={expected_k}, V={expected_v}")
        else:
            if k_err:
                print(f" [FAIL] {k_err}")
                errors.append(k_err)
            if v_err:
                print(f" [FAIL] {v_err}")
                errors.append(v_err)

    return hook


def register_hooks(llm):
    """Register pre and post forward hooks."""
    hooks = []
    model = llm.model_runner.model

    for layer_idx, decoder_layer in enumerate(model.model.layers):
        attn_module = decoder_layer.self_attn.attn

        # PRE-forward hook 1: Verify all previous blocks are in cpu_block_table
        coverage_hook = attn_module.register_forward_pre_hook(make_block_coverage_pre_hook(layer_idx))
        hooks.append(coverage_hook)

        # PRE-forward hook 2: Inject K/V patterns before they're stored to cache
        pattern_hook = attn_module.register_forward_pre_hook(make_kv_pattern_pre_hook(layer_idx))
        hooks.append(pattern_hook)

        # POST-forward hook 1: Verify GPU write_slot contains current chunk's data
        gpu_verify_hook = attn_module.register_forward_hook(make_gpu_write_verification_post_hook(layer_idx))
        hooks.append(gpu_verify_hook)

        # POST-forward hook 2: Verify CPU cache contains correct patterns after offload
        cpu_verify_hook = attn_module.register_forward_hook(make_kv_verification_post_hook(layer_idx))
        hooks.append(cpu_verify_hook)

    return hooks


# ============================================================
# Final Verification
# ============================================================

def verify_final_cpu_state(llm, num_chunks: int):
    """Verify all CPU blocks have correct patterns after prefill completes."""
    print("\n" + "=" * 60)
    print("Final CPU Cache Verification")
    print("=" * 60)

    kvcache_manager = llm.model_runner.kvcache_manager
    oe = kvcache_manager.offload_engine

    num_ok = 0
    num_fail = 0
    fail_details = []

    # After prefill, all chunks should be in CPU
    for chunk_idx in range(num_chunks):
        cpu_block_id = chunk_idx
        expected_k, expected_v = get_expected_pattern(chunk_idx)

        # Check layer 0
        cpu_k = oe.k_cache_cpu[0, cpu_block_id]
        cpu_v = oe.v_cache_cpu[0, cpu_block_id]

        k_ok, k_err = check_pattern(cpu_k, expected_k, f"Final CPU K block {cpu_block_id}")
        v_ok, v_err = check_pattern(cpu_v, expected_v, f"Final CPU V block {cpu_block_id}")

        if k_ok and v_ok:
            num_ok += 1
            # Only print first few and last few
            if chunk_idx < 3 or chunk_idx >= num_chunks - 2:
                actual_k_mean = cpu_k.float().mean().item()
                actual_v_mean = cpu_v.float().mean().item()
                print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), "
                      f"V={expected_v:.1f} ({actual_v_mean:.4f}) [OK]")
            elif chunk_idx == 3:
                print(f" ... (blocks 3 to {num_chunks - 3} verified OK) ...")
        else:
            num_fail += 1
            actual_k_mean = cpu_k.float().mean().item()
            actual_v_mean = cpu_v.float().mean().item()
            print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), "
                  f"V={expected_v:.1f} ({actual_v_mean:.4f}) [FAIL]")
            if k_err:
                errors.append(k_err)
            if v_err:
                errors.append(v_err)

    print(f"\nTotal: {num_ok} OK, {num_fail} FAIL out of {num_chunks} blocks")


def verify_block_coverage_summary(num_chunks: int):
    """Verify that all chunks had complete block coverage during prefill."""
    print("\n" + "=" * 60)
    print("Block Coverage Verification Summary")
    print("=" * 60)

    num_ok = 0
    num_fail = 0
    total_blocks_expected = 0
    total_blocks_computed = 0

    for chunk_idx in range(1, num_chunks):  # Start from 1 (chunk 0 has no previous)
        if chunk_idx not in block_coverage:
            print(f" Chunk {chunk_idx}: NO COVERAGE DATA [FAIL]")
            errors.append(f"Chunk {chunk_idx} has no block coverage data")
            num_fail += 1
            continue

        coverage = block_coverage[chunk_idx]
        expected = coverage['expected']
        actual = coverage['actual']
        missing = expected - actual

        total_blocks_expected += len(expected)
        total_blocks_computed += len(actual)

        if not missing:
            num_ok += 1
        else:
            num_fail += 1

    # Print summary
    if num_fail == 0:
        print(f" All {num_ok} chunks had complete block coverage [OK]")
        print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})")
    else:
        print(f" {num_ok} chunks OK, {num_fail} chunks with missing blocks [FAIL]")
        print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})")

    # Verify the total is correct: sum of 0+1+2+...+(n-1) = n*(n-1)/2
    expected_total = num_chunks * (num_chunks - 1) // 2
    if total_blocks_expected == expected_total:
        print(f" Expected total blocks matches formula: {expected_total} [OK]")
    else:
        print(f" Expected total mismatch: got {total_blocks_expected}, formula gives {expected_total} [FAIL]")
        errors.append("Block coverage total mismatch")


def verify_load_operations_summary(num_chunks: int):
    """Verify all H2D load operations transferred correct data."""
    print("\n" + "=" * 60)
    print("H2D Load Operations Verification Summary")
    print("=" * 60)

    if not load_operations:
        print(" WARNING: No load operations recorded!")
        print(" (This may indicate load verification was not installed)")
        return

    num_ok = 0
    num_fail = 0
    loads_per_chunk = {}

    for op in load_operations:
        chunk_idx = op['chunk_idx']
        if chunk_idx not in loads_per_chunk:
            loads_per_chunk[chunk_idx] = []
        loads_per_chunk[chunk_idx].append(op)

        if op['k_ok'] and op['v_ok']:
            num_ok += 1
        else:
            num_fail += 1

    # Print per-chunk summary for first/last chunks
    for chunk_idx in sorted(loads_per_chunk.keys()):
        ops = loads_per_chunk[chunk_idx]
        chunk_ok = sum(1 for op in ops if op['k_ok'] and op['v_ok'])
        chunk_fail = len(ops) - chunk_ok

        if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or chunk_fail > 0:
            # Show loaded block IDs in order
            block_ids = [op['cpu_block_id'] for op in ops]
            if chunk_fail == 0:
                print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks {block_ids} [OK]")
            else:
                print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks, {chunk_fail} FAILED [FAIL]")
                for op in ops:
                    if not (op['k_ok'] and op['v_ok']):
                        print(f"   CPU block {op['cpu_block_id']} -> slot {op['slot_idx']}: "
                              f"expected K={op['expected_k']:.1f}/V={op['expected_v']:.1f}, "
                              f"got K={op['actual_k']:.4f}/V={op['actual_v']:.4f}")
        elif chunk_idx == 4:
            print(f" ... (chunks 4-{num_chunks - 3} load verification running silently) ...")

    # Print overall summary
    print(f"\n Total load operations: {len(load_operations)}")
    print(f" Successful: {num_ok}, Failed: {num_fail}")

    if num_fail == 0:
        print(" All H2D transfers verified correct [OK]")
    else:
        print(f" {num_fail} H2D transfers had incorrect data [FAIL]")


# ============================================================
# Main Test Script
# ============================================================

if __name__ == "__main__":
|
|
||||||
print("=" * 60)
|
|
||||||
print("Test: CPU Offload Correctness with Distinctive KV Patterns")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"Input: {INPUT_LEN} tokens, {INPUT_LEN // BLOCK_SIZE} chunks")
|
|
||||||
print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}")
|
|
||||||
print(f"Pattern: K=chunk_idx+1, V=-(chunk_idx+1)")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# 1. Initialize LLM
|
|
||||||
print("Initializing LLM...")
|
|
||||||
llm = LLM(
|
|
||||||
MODEL_PATH,
|
|
||||||
enforce_eager=True,
|
|
||||||
max_model_len=MAX_MODEL_LEN,
|
|
||||||
max_num_batched_tokens=MAX_MODEL_LEN,
|
|
||||||
enable_cpu_offload=True,
|
|
||||||
kvcache_block_size=BLOCK_SIZE,
|
|
||||||
num_gpu_blocks=NUM_GPU_BLOCKS,
|
|
||||||
dtype="float16",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Register hooks
|
|
||||||
hooks = register_hooks(llm)
|
|
||||||
print(f"Registered {len(hooks)} hooks")
|
|
||||||
|
|
||||||
# 3. Install load verification (instrument load_to_slot_layer)
|
|
||||||
install_load_verification(llm)
|
|
||||||
|
|
||||||
# 4. Generate prompt
|
|
||||||
seed(42)
|
|
||||||
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
|
|
||||||
num_chunks = INPUT_LEN // BLOCK_SIZE
|
|
||||||
|
|
||||||
# 5. Run prefill
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("Running Prefill with KV Pattern Injection...")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
|
|
||||||
outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
|
|
||||||
|
|
||||||
# 6. Remove hooks and uninstall load verification
|
|
||||||
for hook in hooks:
|
|
||||||
hook.remove()
|
|
||||||
uninstall_load_verification()
|
|
||||||
|
|
||||||
# 7. Final verification
|
|
||||||
verify_final_cpu_state(llm, num_chunks)
|
|
||||||
|
|
||||||
# 8. Block coverage summary
|
|
||||||
verify_block_coverage_summary(num_chunks)
|
|
||||||
|
|
||||||
# 9. H2D load operations summary
|
|
||||||
verify_load_operations_summary(num_chunks)
|
|
||||||
|
|
||||||
# 10. Report results
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
if errors:
|
|
||||||
print(f"test_offload_correctness: FAILED ({len(errors)} errors)")
|
|
||||||
for err in errors[:10]: # Show first 10 errors
|
|
||||||
print(f" - {err}")
|
|
||||||
exit(1)
|
|
||||||
else:
|
|
||||||
print("test_offload_correctness: PASSED")
|
|
||||||
print("=" * 60)
|
|
||||||
@@ -1,119 +0,0 @@
"""
Test script for OffloadEngine - CPU-GPU KV cache transfer engine.

Demonstrates: ring buffer, H2D/D2H transfers, CUDA events, KV access.
"""

import torch
from nanovllm.kvcache.offload_engine import OffloadEngine

# ============================================================
# Utility Functions
# ============================================================

def verify(tensor: torch.Tensor, expected: float, name: str) -> None:
    """Verify tensor contains expected value."""
    actual = tensor.mean().item()
    assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}"

# ============================================================
# Configuration
# ============================================================

NUM_LAYERS = 4
NUM_GPU_BLOCKS = 8
NUM_CPU_BLOCKS = 16
BLOCK_SIZE = 64
NUM_KV_HEADS = 4
HEAD_DIM = 32

# ============================================================
# Main Test Script
# ============================================================

# 1. Initialize
engine = OffloadEngine(
    num_layers=NUM_LAYERS,
    num_gpu_blocks=NUM_GPU_BLOCKS,
    num_cpu_blocks=NUM_CPU_BLOCKS,
    block_size=BLOCK_SIZE,
    num_kv_heads=NUM_KV_HEADS,
    head_dim=HEAD_DIM,
    dtype=torch.float16,
)

# 2. Ring buffer slot management
for chunk_idx in range(12):
    write_slot = engine.get_write_slot_for_prefill(chunk_idx)
    load_slots = engine.get_load_slots_for_prefill(write_slot)

    print("chunk idx", chunk_idx, "write slots:", write_slot, "load slots:", load_slots)

    assert write_slot == chunk_idx % engine.num_ring_slots
    assert write_slot not in load_slots

assert engine.decode_slot == 0
assert engine.get_load_slots_for_decode() == list(range(1, NUM_GPU_BLOCKS))

# 3. Per-slot per-layer H2D transfer
engine.k_cache_cpu[0, 0].fill_(42.0)
engine.v_cache_cpu[0, 0].fill_(42.5)

engine.load_to_slot_layer(slot_idx=1, layer_id=0, cpu_block_id=0)
engine.wait_slot_layer(slot_idx=1, layer_id=0)

verify(engine.k_cache_gpu[0, 1], 42.0, "H2D K")
verify(engine.v_cache_gpu[0, 1], 42.5, "H2D V")

# 4. Compute-done event (pipeline safety)
engine.record_slot_compute_done(slot_idx=1, layer_id=0)

engine.k_cache_cpu[0, 1].fill_(100.0)
engine.v_cache_cpu[0, 1].fill_(100.5)
engine.load_to_slot_layer(slot_idx=1, layer_id=0, cpu_block_id=1)
engine.wait_slot_layer(slot_idx=1, layer_id=0)

verify(engine.k_cache_gpu[0, 1], 100.0, "Reuse K")
verify(engine.v_cache_gpu[0, 1], 100.5, "Reuse V")

# 5. D2H offload
engine.k_cache_gpu[1, 2].fill_(77.0)
engine.v_cache_gpu[1, 2].fill_(77.5)

engine.offload_slot_to_cpu(slot_idx=2, cpu_block_id=5)
engine.wait_slot_offload(slot_idx=2)

verify(engine.k_cache_cpu[1, 5], 77.0, "D2H K")
verify(engine.v_cache_cpu[1, 5], 77.5, "D2H V")

# 6. KV access methods
k, v = engine.get_kv_for_slot(slot_idx=1, layer_id=0)
assert k.shape == (1, BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM)

k, v = engine.get_kv_for_slots(layer_id=0, slot_indices=[0, 1, 2])
assert k.shape == (1, 3 * BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM)

engine.k_cache_gpu[0, engine.decode_slot].fill_(33.0)
k, v = engine.get_kv_for_decode_slot_accumulated(layer_id=0, num_tokens=10)
assert k.shape == (1, 10, NUM_KV_HEADS, HEAD_DIM)
verify(k, 33.0, "Decode slot K")

# 7. Batch transfer
cpu_blocks = [2, 3, 4]
gpu_slots = [3, 4, 5]
for cpu_id in cpu_blocks:
    engine.k_cache_cpu[0, cpu_id].fill_(50.0 + cpu_id)

engine.load_cpu_blocks_to_gpu_slots(layer_id=0, cpu_block_ids=cpu_blocks, gpu_slot_ids=gpu_slots)

for cpu_id, gpu_slot in zip(cpu_blocks, gpu_slots):
    verify(engine.k_cache_gpu[0, gpu_slot], 50.0 + cpu_id, f"Batch slot {gpu_slot}")

# 8. Gather indices (CUDA graph compatible)
engine.update_gather_indices(layer_id=0, mappings=[(0, 0), (1, 1), (2, 2)])
assert engine.gather_indices_gpu[0, :3].tolist() == [0, 1, 2]

engine.clear_gather_indices(layer_id=0)
assert engine.gather_indices_gpu[0, 0].item() == -1

print("test_offload_engine: PASSED")
@@ -1,51 +0,0 @@
"""
Test script for chunked prefill with CPU offload.

Demonstrates: LLM initialization, prefill execution with CPU offload enabled.
"""

import os
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"

from random import randint, seed
from nanovllm import LLM, SamplingParams


# ============================================================
# Configuration
# ============================================================

MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 32 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024

# ============================================================
# Main Test Script
# ============================================================

# 1. Initialize LLM with CPU offload
llm = LLM(
    MODEL_PATH,
    enforce_eager=True,
    max_model_len=MAX_MODEL_LEN,
    max_num_batched_tokens=MAX_MODEL_LEN,
    enable_cpu_offload=True,
    kvcache_block_size=1024,
    num_gpu_blocks=NUM_GPU_BLOCKS,
)

# 2. Generate random prompt tokens
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]

# 3. Run prefill (max_tokens=1 to focus on prefill only)
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)

# 4. Verify output
assert len(outputs) == 1
assert "token_ids" in outputs[0]
assert len(outputs[0]["token_ids"]) == 1

print("test_prefill: PASSED")
136
tests/test_quest_policy.py
Normal file
@@ -0,0 +1,136 @@
"""
Test for QuestPolicy block selection with GQA (Grouped Query Attention).

Demonstrates the key limitation: scores are AVERAGED across heads,
so blocks strongly needed by one head but not others may be dropped.

This is the expected Quest behavior - not a bug.
"""

import torch
from nanovllm.kvcache.sparse import (
    create_sparse_policy,
    SparsePolicyType,
    PolicyContext,
)

# ============================================================
# Test: Per-Head Score Averaging in GQA
# ============================================================

# Determine device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running test on device: {device}")

# Setup: 2 KV heads, 4 query heads (GQA group_size=2)
# topk=2 to make selection competitive

quest = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=2, threshold_blocks=0)
quest.initialize(
    num_layers=1,
    num_kv_heads=2,
    head_dim=4,
    num_cpu_blocks=6,
    dtype=torch.float32,
    device=device,  # Metadata stored on GPU
)

metadata = quest.metadata

def set_key(block_id, head_id, values):
    """Set both key_min and key_max to same values for deterministic scoring."""
    # Values need to be on the same device as metadata
    tensor = torch.tensor(values, device=device)
    metadata.key_min[block_id, 0, head_id, :] = tensor
    metadata.key_max[block_id, 0, head_id, :] = tensor

# ============================================================
# Design: Different heads want different blocks
# ============================================================
#
# Query = [1,1,1,1] for all heads, so score = sum(key values)
#
# Block | Head 0 | Head 1 | Average | Result
# ------|--------|--------|---------|--------
# 0     | +4     | -4     | 0       | Head0 wants, Head1 doesn't → DROPPED
# 1     | -4     | +4     | 0       | Head1 wants, Head0 doesn't → DROPPED
# 2     | +4     | +4     | +4      | Both want → SELECTED (rank 1)
# 3     | +3     | +3     | +3      | Both want → SELECTED (rank 2)
# 4     | +4     | 0      | +2      | Head0 strongly wants, Head1 neutral → rank 3
# 5     | 0      | +4     | +2      | Head1 strongly wants, Head0 neutral → rank 3

# Block 0: Head 0 strongly wants, Head 1 strongly rejects
set_key(0, 0, [1.0, 1.0, 1.0, 1.0])      # head0: +4
set_key(0, 1, [-1.0, -1.0, -1.0, -1.0])  # head1: -4

# Block 1: Head 1 strongly wants, Head 0 strongly rejects
set_key(1, 0, [-1.0, -1.0, -1.0, -1.0])  # head0: -4
set_key(1, 1, [1.0, 1.0, 1.0, 1.0])      # head1: +4

# Block 2: Both heads want equally (highest average)
set_key(2, 0, [1.0, 1.0, 1.0, 1.0])  # head0: +4
set_key(2, 1, [1.0, 1.0, 1.0, 1.0])  # head1: +4

# Block 3: Both heads want moderately
set_key(3, 0, [0.75, 0.75, 0.75, 0.75])  # head0: +3
set_key(3, 1, [0.75, 0.75, 0.75, 0.75])  # head1: +3

# Block 4: Head 0 strongly wants, Head 1 neutral
set_key(4, 0, [1.0, 1.0, 1.0, 1.0])  # head0: +4
set_key(4, 1, [0.0, 0.0, 0.0, 0.0])  # head1: 0

# Block 5: Head 1 strongly wants, Head 0 neutral
set_key(5, 0, [0.0, 0.0, 0.0, 0.0])  # head0: 0
set_key(5, 1, [1.0, 1.0, 1.0, 1.0])  # head1: +4

# ============================================================
# Run selection
# ============================================================

# Query on same device as metadata
query = torch.ones(1, 4, 4, device=device)  # GQA: 4 query heads → 2 KV heads

ctx = PolicyContext(
    query_chunk_idx=0,
    num_query_chunks=1,
    layer_id=0,
    query=query,
    is_prefill=False,
    block_size=1024,
    total_kv_len=6144,
)

available = list(range(6))
selected = quest.select_blocks(available, ctx)

# ============================================================
# Verify: Averaging behavior
# ============================================================

# topk=2, so only blocks 2 (+4 avg) and 3 (+3 avg) should be selected
assert len(selected) == 2, f"Expected 2 blocks, got {len(selected)}"
assert selected == [2, 3], f"Expected [2, 3], got {selected}"

# Key insight: blocks 0 and 1 have score +4 for ONE head,
# but they cancel out due to averaging with the other head's -4
assert 0 not in selected, "Block 0 should NOT be selected (head scores cancel out)"
assert 1 not in selected, "Block 1 should NOT be selected (head scores cancel out)"

# Blocks 4 and 5 have +4 for one head, 0 for other → avg=+2
# But +2 < +3 (block 3), so they don't make the top-2
assert 4 not in selected, "Block 4 avg=+2 < block 3 avg=+3"
assert 5 not in selected, "Block 5 avg=+2 < block 3 avg=+3"

print("✓ Block 2 selected: both heads want it (+4, +4) → avg=+4")
print("✓ Block 3 selected: both heads want it (+3, +3) → avg=+3")
print("✓ Block 0 NOT selected: head0=+4, head1=-4 → avg=0 (cancel out)")
print("✓ Block 1 NOT selected: head0=-4, head1=+4 → avg=0 (cancel out)")
print("✓ Block 4 NOT selected: head0=+4, head1=0 → avg=+2 (lower rank)")
print("✓ Block 5 NOT selected: head0=0, head1=+4 → avg=+2 (lower rank)")

# Verify metadata is on correct device
assert metadata.key_min.device.type == device.type, f"key_min on wrong device: {metadata.key_min.device}"
assert metadata.key_max.device.type == device.type, f"key_max on wrong device: {metadata.key_max.device}"
print(f"✓ Metadata stored on {device.type.upper()}")

print("\ntest_quest_policy: PASSED")
199
tests/test_sequential.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
"""
|
||||||
|
Sequential inference test for LLM.
|
||||||
|
|
||||||
|
Tests: After completing one prompt, the system can correctly handle
|
||||||
|
a second prompt with a clean state (first prompt's KV cache deallocated).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from nanovllm import LLM, SamplingParams
|
||||||
|
from utils import generate_needle_prompt, check_needle_answer
|
||||||
|
|
||||||
|
|
||||||
|
def run_sequential_test(
|
||||||
|
model_path: str,
|
||||||
|
max_model_len: int,
|
||||||
|
input_len: int,
|
||||||
|
num_gpu_blocks: int = 4,
|
||||||
|
block_size: int = 1024,
|
||||||
|
enable_cpu_offload: bool = False,
|
||||||
|
verbose: bool = True,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Run sequential inference test with two different prompts.
|
||||||
|
|
||||||
|
Each prompt has a different needle value. Both must be retrieved correctly.
|
||||||
|
"""
|
||||||
|
if verbose:
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"Sequential Inference Test")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f"Model: {model_path}")
|
||||||
|
print(f"Max model len: {max_model_len}")
|
||||||
|
print(f"Input length: {input_len}")
|
||||||
|
print(f"Block size: {block_size}")
|
||||||
|
print(f"CPU offload: {enable_cpu_offload}")
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
# Initialize LLM once
|
||||||
|
llm_kwargs = {
|
||||||
|
"enforce_eager": True,
|
||||||
|
"max_model_len": max_model_len,
|
||||||
|
"max_num_batched_tokens": max_model_len,
|
||||||
|
"enable_cpu_offload": enable_cpu_offload,
|
||||||
|
"kvcache_block_size": block_size,
|
||||||
|
}
|
||||||
|
if enable_cpu_offload:
|
||||||
|
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
||||||
|
|
||||||
|
llm = LLM(model_path, **llm_kwargs)
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.6,
|
||||||
|
max_tokens=32,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Test 1: First prompt with needle value "1234"
|
||||||
|
# ============================================================
|
||||||
|
needle_value_1 = "1234"
|
||||||
|
if verbose:
|
||||||
|
print(f"\n[Test 1] Generating prompt with needle value: {needle_value_1}")
|
||||||
|
|
||||||
|
prompt_1, expected_1 = generate_needle_prompt(
|
||||||
|
tokenizer=llm.tokenizer,
|
||||||
|
target_length=input_len,
|
||||||
|
needle_position=0.5,
|
||||||
|
needle_value=needle_value_1,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs_1 = llm.generate([prompt_1], sampling_params, use_tqdm=True)
|
||||||
|
output_text_1 = outputs_1[0]["text"]
|
||||||
|
passed_1 = check_needle_answer(output_text_1, expected_1)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f" Expected: {expected_1}")
|
||||||
|
print(f" Output: {output_text_1[:100]}...")
|
||||||
|
print(f" Status: {'PASSED' if passed_1 else 'FAILED'}")
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Test 2: Second prompt with needle value "5678"
|
||||||
|
# ============================================================
|
||||||
|
needle_value_2 = "5678"
|
||||||
|
if verbose:
|
||||||
|
print(f"\n[Test 2] Generating prompt with needle value: {needle_value_2}")
|
||||||
|
|
||||||
|
prompt_2, expected_2 = generate_needle_prompt(
|
||||||
|
tokenizer=llm.tokenizer,
|
||||||
|
target_length=input_len,
|
||||||
|
needle_position=0.5,
|
||||||
|
needle_value=needle_value_2,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs_2 = llm.generate([prompt_2], sampling_params, use_tqdm=True)
|
||||||
|
output_text_2 = outputs_2[0]["text"]
|
||||||
|
passed_2 = check_needle_answer(output_text_2, expected_2)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f" Expected: {expected_2}")
|
||||||
|
print(f" Output: {output_text_2[:100]}...")
|
||||||
|
print(f" Status: {'PASSED' if passed_2 else 'FAILED'}")
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Test 3: Third prompt - repeat first needle to ensure no cross-contamination
|
||||||
|
# ============================================================
|
||||||
|
needle_value_3 = "9999"
|
||||||
|
if verbose:
|
||||||
|
print(f"\n[Test 3] Generating prompt with needle value: {needle_value_3}")
|
||||||
|
|
||||||
|
prompt_3, expected_3 = generate_needle_prompt(
|
||||||
|
tokenizer=llm.tokenizer,
|
||||||
|
target_length=input_len,
|
||||||
|
needle_position=0.5,
|
||||||
|
needle_value=needle_value_3,
|
||||||
|
)
|
||||||
|
|
||||||
|
outputs_3 = llm.generate([prompt_3], sampling_params, use_tqdm=True)
|
||||||
|
output_text_3 = outputs_3[0]["text"]
|
||||||
|
passed_3 = check_needle_answer(output_text_3, expected_3)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f" Expected: {expected_3}")
|
||||||
|
print(f" Output: {output_text_3[:100]}...")
|
||||||
|
print(f" Status: {'PASSED' if passed_3 else 'FAILED'}")
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Summary
|
||||||
|
# ============================================================
|
||||||
|
all_passed = passed_1 and passed_2 and passed_3
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"Summary")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f"Test 1 (needle={needle_value_1}): {'PASSED' if passed_1 else 'FAILED'}")
|
||||||
|
print(f"Test 2 (needle={needle_value_2}): {'PASSED' if passed_2 else 'FAILED'}")
|
||||||
|
print(f"Test 3 (needle={needle_value_3}): {'PASSED' if passed_3 else 'FAILED'}")
|
||||||
|
print(f"Overall: {'PASSED' if all_passed else 'FAILED'}")
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
return all_passed
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Sequential inference test")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model", "-m",
|
||||||
|
type=str,
|
||||||
|
default=os.path.expanduser("~/models/Qwen3-0.6B/"),
|
||||||
|
help="Path to model"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-model-len",
|
||||||
|
type=int,
|
||||||
|
default=36 * 1024,
|
||||||
|
help="Maximum model context length"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--input-len",
|
||||||
|
type=int,
|
||||||
|
default=8 * 1024,
|
||||||
|
help="Target input sequence length"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-gpu-blocks",
|
||||||
|
type=int,
|
||||||
|
default=2,
|
||||||
|
help="Number of GPU blocks for CPU offload"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--block-size",
|
||||||
|
type=int,
|
||||||
|
default=1024,
|
||||||
|
help="KV cache block size"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-offload",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable CPU offload"
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
passed = run_sequential_test(
|
||||||
|
model_path=args.model,
|
||||||
|
max_model_len=args.max_model_len,
|
||||||
|
input_len=args.input_len,
|
||||||
|
num_gpu_blocks=args.num_gpu_blocks,
|
||||||
|
block_size=args.block_size,
|
||||||
|
enable_cpu_offload=args.enable_offload,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if passed:
|
||||||
|
print("test_sequential: PASSED")
|
||||||
|
else:
|
||||||
|
print("test_sequential: FAILED")
|
||||||
|
exit(1)
|
||||||
@@ -1,176 +0,0 @@
|
|||||||
"""
|
|
||||||
Tests for CUDA sgDMA (cudaMemcpy2D) extension.
|
|
||||||
|
|
||||||
Author: Zijie Tian
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import time
|
|
||||||
from nanovllm.comm import memcpy_2d
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Configuration
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
num_layers = 32
|
|
||||||
num_blocks = 10
|
|
||||||
block_size = 4096
|
|
||||||
num_kv_heads = 8
|
|
||||||
head_dim = 128
|
|
||||||
dtype = torch.float16
|
|
||||||
|
|
||||||
@property
|
|
||||||
def features_per_block(self):
|
|
||||||
return self.block_size * self.num_kv_heads * self.head_dim
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bytes_per_block(self):
|
|
||||||
return self.features_per_block * self.dtype.itemsize
|
|
||||||
|
|
||||||
@property
|
|
||||||
def bytes_per_layer(self):
|
|
||||||
return self.num_blocks * self.bytes_per_block
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Performance Benchmark
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def benchmark_sgdma():
|
|
||||||
"""Benchmark cudaMemcpy2D vs standard PyTorch methods."""
|
|
||||||
print("\n=== Performance Benchmark ===")
|
|
||||||
|
|
||||||
cfg = Config()
|
|
||||||
|
|
||||||
print(f" Configuration:")
|
|
||||||
print(f" num_layers: {cfg.num_layers}")
|
|
||||||
print(f" num_blocks: {cfg.num_blocks}")
|
|
||||||
print(f" block_size: {cfg.block_size}")
|
|
||||||
print(f" dtype: {cfg.dtype}")
|
|
||||||
print(f" bytes_per_block: {cfg.bytes_per_block / 1024:.1f} KB")
|
|
||||||
print(f" total transfer size: {cfg.num_layers * cfg.bytes_per_block / 1024 / 1024:.1f} MB")
|
|
||||||
|
|
||||||
num_iterations = 10
|
|
||||||
warmup = 3
|
|
||||||
test_block_id = 5
|
|
||||||
|
|
||||||
# Allocate memory
|
|
||||||
cpu_strided = torch.randn(
|
|
||||||
cfg.num_layers,
|
|
||||||
cfg.num_blocks,
|
|
||||||
cfg.features_per_block,
|
|
||||||
dtype=cfg.dtype,
|
|
||||||
pin_memory=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Method A: cudaMemcpy2D with sgDMA
|
|
||||||
# ========================================
|
|
||||||
gpu_buffer_a = torch.empty(cfg.num_layers, cfg.features_per_block, dtype=cfg.dtype, device='cuda')
|
|
||||||
|
|
||||||
spitch = cfg.bytes_per_layer
|
|
||||||
dpitch = cfg.bytes_per_block
|
|
||||||
width = cfg.bytes_per_block
|
|
||||||
height = cfg.num_layers
|
|
||||||
src_view = cpu_strided[:, test_block_id, :]
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
for _ in range(warmup):
|
|
||||||
memcpy_2d(gpu_buffer_a, src_view, dpitch, spitch, width, height, "h2d")
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Benchmark
|
|
||||||
start = time.perf_counter()
|
|
||||||
for _ in range(num_iterations):
|
|
||||||
memcpy_2d(gpu_buffer_a, src_view, dpitch, spitch, width, height, "h2d")
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
elapsed_a = time.perf_counter() - start
|
|
||||||
|
|
||||||
avg_time_a = elapsed_a / num_iterations * 1000 # ms
|
|
||||||
bandwidth_a = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_a
|
|
||||||
|
|
||||||
print(f"\n Method A (cudaMemcpy2D sgDMA):")
|
|
||||||
print(f" Avg time: {avg_time_a:.3f} ms")
|
|
||||||
print(f" Bandwidth: {bandwidth_a:.2f} GB/s")
|
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Method B: PyTorch .cuda() on strided view
|
|
||||||
# ========================================
|
|
||||||
# Warmup
|
|
||||||
for _ in range(warmup):
|
|
||||||
_ = cpu_strided[:, test_block_id, :].cuda()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Benchmark
|
|
||||||
start = time.perf_counter()
|
|
||||||
for _ in range(num_iterations):
|
|
||||||
_ = cpu_strided[:, test_block_id, :].cuda()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
elapsed_b = time.perf_counter() - start
|
|
||||||
|
|
||||||
avg_time_b = elapsed_b / num_iterations * 1000 # ms
|
|
||||||
bandwidth_b = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_b
|
|
||||||
|
|
||||||
print(f"\n Method B (PyTorch .cuda() on strided):")
|
|
||||||
print(f" Avg time: {avg_time_b:.3f} ms")
|
|
||||||
print(f" Bandwidth: {bandwidth_b:.2f} GB/s")
|
|
||||||
|
|
||||||
# ========================================
|
|
||||||
# Method C: PyTorch .cuda() on contiguous (pinned)
|
|
||||||
# ========================================
|
|
||||||
# Create contiguous version with pinned memory
|
|
||||||
cpu_contiguous = torch.empty(
|
|
||||||
cfg.num_layers,
|
|
||||||
cfg.features_per_block,
|
|
||||||
dtype=cfg.dtype,
|
|
||||||
pin_memory=True
|
|
||||||
)
|
|
||||||
cpu_contiguous.copy_(cpu_strided[:, test_block_id, :])
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
for _ in range(warmup):
|
|
||||||
_ = cpu_contiguous.cuda()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Benchmark
|
|
||||||
start = time.perf_counter()
|
|
||||||
for _ in range(num_iterations):
|
|
||||||
_ = cpu_contiguous.cuda()
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
elapsed_c = time.perf_counter() - start
|
|
||||||
|
|
||||||
avg_time_c = elapsed_c / num_iterations * 1000 # ms
|
|
||||||
bandwidth_c = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_c
|
|
||||||
|
|
||||||
print(f"\n Method C (PyTorch .cuda() on contiguous):")
|
|
||||||
print(f" Avg time: {avg_time_c:.3f} ms")
|
|
||||||
print(f" Bandwidth: {bandwidth_c:.2f} GB/s")
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
print(f"\n ========================================")
|
|
||||||
print(f" Performance Summary:")
|
|
||||||
print(f" Method A vs Method B: {bandwidth_a / bandwidth_b:.2f}x speedup")
|
|
||||||
print(f" Method A vs Method C: {bandwidth_a / bandwidth_c * 100:.2f}%")
|
|
||||||
print(f" ========================================")
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Main
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("=== CUDA sgDMA (cudaMemcpy2D) Benchmark ===")
|
|
||||||
|
|
||||||
# Check CUDA availability
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("CUDA not available. Skipping benchmark.")
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
# Print GPU info
|
|
||||||
print(f"Using GPU: {torch.cuda.get_device_name()}")
|
|
||||||
|
|
||||||
# Run benchmark
|
|
||||||
benchmark_sgdma()
|
|
||||||
|
|
||||||
print("\n=== Benchmark Complete ===")
|
|
||||||
181
tests/utils.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
"""
|
||||||
|
Test utilities for nano-vllm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Needle-in-Haystack Test Utilities
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# Haystack filler paragraphs (various topics to create realistic context)
|
||||||
|
HAYSTACK_PARAGRAPHS = [
|
||||||
|
"The weather today is quite pleasant with clear skies and moderate temperatures. "
|
||||||
|
"Many people are enjoying outdoor activities in the park. "
|
||||||
|
"Birds are singing in the trees and children are playing on the swings. ",
|
||||||
|
|
||||||
|
"In the world of technology, new innovations continue to emerge every day. "
|
||||||
|
"Researchers are working on advanced algorithms and computing systems. "
|
||||||
|
"The future of artificial intelligence looks promising with many breakthroughs. ",
|
||||||
|
|
||||||
|
"The history of human civilization spans thousands of years. "
|
||||||
|
"Ancient cultures developed writing, mathematics, and astronomy. "
|
||||||
|
"Trade routes connected distant lands and facilitated cultural exchange. ",
|
||||||
|
|
||||||
|
"Modern cooking combines traditional techniques with new ingredients. "
|
||||||
|
"Chefs around the world experiment with flavors and presentations. "
|
||||||
|
"Food brings people together and creates memorable experiences. ",
|
||||||
|
|
||||||
|
"The ocean covers more than seventy percent of Earth's surface. "
|
||||||
|
"Marine ecosystems support an incredible diversity of life forms. "
|
||||||
|
"Scientists continue to discover new species in the deep sea. ",
|
||||||
|
|
||||||
|
"Music has been a part of human culture since prehistoric times. "
|
||||||
|
"Different genres evolved across various regions and time periods. "
|
||||||
|
"Today, people can access millions of songs through digital platforms. ",
|
||||||
|
|
||||||
|
"Space exploration has revealed many secrets about our universe. "
|
||||||
|
"Telescopes can observe galaxies billions of light years away. "
|
||||||
|
"Future missions aim to establish human presence on other planets. ",
|
||||||
|
|
||||||
|
"The study of languages reveals patterns in human cognition. "
|
||||||
|
"Linguists analyze grammar, semantics, and phonetics across cultures. "
|
||||||
|
"Language continues to evolve with new words and expressions. ",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_needle_prompt(
|
||||||
|
tokenizer,
|
||||||
|
target_length: int,
|
||||||
|
needle_position: float = 0.5,
|
||||||
|
needle_value: str = "7492",
|
||||||
|
use_chat_template: bool = True,
|
||||||
|
verbose: bool = True,
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
"""
|
||||||
|
Generate a needle-in-haystack prompt of exactly target_length tokens.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tokenizer: HuggingFace tokenizer for length estimation
|
||||||
|
target_length: Target total sequence length in tokens
|
||||||
|
needle_position: Where to place needle (0.0=start, 0.5=middle, 1.0=end)
|
||||||
|
needle_value: The secret value to hide in the haystack
|
||||||
|
use_chat_template: Whether to use chat template for instruct models
|
||||||
|
verbose: Whether to print generation info
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(prompt, expected_answer): The full prompt and the expected needle value
|
||||||
|
"""
|
||||||
|
# The needle sentence
|
||||||
|
needle = f"The secret number you need to remember is {needle_value}. This is very important. "
|
||||||
|
|
||||||
|
# Question text
|
||||||
|
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
|
||||||
|
question_text = "/no_think Answer only with the secret number mentioned above, nothing else:"
|
||||||
|
else:
|
||||||
|
question_text = "\n\nQuestion: What is the secret number mentioned in the text above?\nAnswer: The secret number is"
|
||||||
|
|
||||||
|
def build_prompt(haystack_parts, needle_idx):
|
||||||
|
"""Build full prompt from haystack parts with needle inserted."""
|
||||||
|
parts = haystack_parts.copy()
|
||||||
|
parts.insert(needle_idx, needle)
|
||||||
|
full_text = "".join(parts)
|
||||||
|
|
||||||
|
if use_chat_template and hasattr(tokenizer, 'apply_chat_template'):
|
||||||
|
messages = [{"role": "user", "content": f"{full_text}\n\n{question_text}"}]
|
||||||
|
return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||||
|
else:
|
||||||
|
return full_text + question_text
|
||||||
|
|
||||||
|
def count_tokens(prompt):
|
||||||
|
return len(tokenizer.encode(prompt, add_special_tokens=False))
|
||||||
|
|
||||||
|
def get_needle_idx(parts):
|
||||||
|
idx = int(len(parts) * needle_position)
|
||||||
|
return max(0, min(idx, len(parts)))
|
||||||
|
|
||||||
|
# Pre-compute tokens per paragraph for efficiency (avoid O(n²) tokenization)
|
||||||
|
para_tokens = []
|
||||||
|
for para in HAYSTACK_PARAGRAPHS:
|
||||||
|
para_tokens.append(len(tokenizer.encode(para, add_special_tokens=False)))
|
||||||
|
avg_tokens_per_para = sum(para_tokens) / len(para_tokens)
|
||||||
|
|
||||||
|
# Estimate overhead (needle + question + chat template)
|
||||||
|
overhead_prompt = build_prompt([HAYSTACK_PARAGRAPHS[0]], 0)
|
||||||
|
overhead_tokens = count_tokens(overhead_prompt) - para_tokens[0]
|
||||||
|
|
||||||
|
# Phase 1: Estimate number of paragraphs needed
|
||||||
|
estimated_paras = int((target_length - overhead_tokens) / avg_tokens_per_para) + 1
|
||||||
|
|
||||||
|
# Build haystack with estimated paragraphs
|
||||||
|
haystack_parts = []
|
||||||
|
for i in range(estimated_paras):
|
||||||
|
haystack_parts.append(HAYSTACK_PARAGRAPHS[i % len(HAYSTACK_PARAGRAPHS)])
|
||||||
|
|
||||||
|
# Phase 2: Adjust to get closer to target
|
||||||
|
prompt = build_prompt(haystack_parts, get_needle_idx(haystack_parts))
|
||||||
|
current_tokens = count_tokens(prompt)
|
||||||
|
|
||||||
|
# Add more if under target
|
||||||
|
para_idx = len(haystack_parts)
|
||||||
|
while current_tokens < target_length and para_idx < 100000:
|
||||||
|
para = HAYSTACK_PARAGRAPHS[para_idx % len(HAYSTACK_PARAGRAPHS)]
|
||||||
|
haystack_parts.append(para)
|
||||||
|
current_tokens += para_tokens[para_idx % len(HAYSTACK_PARAGRAPHS)]
|
||||||
|
para_idx += 1
|
||||||
|
|
||||||
|
# Remove if over target
|
||||||
|
while current_tokens > target_length + 100 and len(haystack_parts) > 1:
|
||||||
|
removed_para_idx = (len(haystack_parts) - 1) % len(HAYSTACK_PARAGRAPHS)
|
||||||
|
haystack_parts.pop()
|
||||||
|
current_tokens -= para_tokens[removed_para_idx]
|
||||||
|
|
||||||
|
# Final build
|
||||||
|
needle_idx = get_needle_idx(haystack_parts)
|
||||||
|
prompt = build_prompt(haystack_parts, needle_idx)
|
||||||
|
|
||||||
|
actual_tokens = count_tokens(prompt)
|
||||||
|
if verbose:
|
||||||
|
print(f"[NeedleTest] Target: {target_length}, Actual: {actual_tokens} tokens (diff={actual_tokens - target_length})")
|
||||||
|
|
||||||
|
return prompt, needle_value
|
||||||
|
|
||||||
|
|
||||||
|
def check_needle_answer(output_text: str, expected: str) -> bool:
|
||||||
|
"""Check if the model output contains the expected needle value."""
|
||||||
|
# Clean output - remove special tokens and whitespace
|
||||||
|
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
|
||||||
|
output_clean = ' '.join(output_clean.split()).lower()
|
||||||
|
expected_clean = expected.strip().lower()
|
||||||
|
|
||||||
|
# Check if expected value appears in output
|
||||||
|
# Also try to find it as a standalone number
|
||||||
|
if expected_clean in output_clean:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Try to extract numbers and check if expected is among them
|
||||||
|
numbers = re.findall(r'\d+', output_clean)
|
||||||
|
return expected_clean in numbers
|
||||||
|
|
||||||
|
|
||||||
|
def generate_random_token_ids(
|
||||||
|
length: int,
|
||||||
|
vocab_size: int = 10000,
|
||||||
|
seed: int = 42,
|
||||||
|
) -> list:
|
||||||
|
"""
|
||||||
|
Generate random token IDs for testing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
length: Number of tokens to generate
|
||||||
|
vocab_size: Maximum token ID (exclusive)
|
||||||
|
seed: Random seed for reproducibility
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of random token IDs
|
||||||
|
"""
|
||||||
|
from random import randint, seed as set_seed
|
||||||
|
set_seed(seed)
|
||||||
|
return [randint(0, vocab_size - 1) for _ in range(length)]
|
||||||