Compare commits
5 Commits
tzj/vs_off
...
d9890aa2cd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d9890aa2cd | ||
|
|
5a837c8c83 | ||
|
|
d1bbb7efe2 | ||
|
|
1a78ae74d5 | ||
|
|
c254c8c330 |
@@ -66,27 +66,33 @@ print("test_xxx: PASSED")
|
|||||||
|
|
||||||
## Running Tests
|
## Running Tests
|
||||||
|
|
||||||
Use PYTHONPATH for multi-instance isolation (no pip install needed):
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Run a specific test
|
# Run a specific test
|
||||||
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python tests/test_offload_engine.py
|
python tests/test_offload_engine.py
|
||||||
|
|
||||||
# Run with specific GPU
|
# Run with specific GPU
|
||||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python tests/test_ring_buffer.py
|
CUDA_VISIBLE_DEVICES=0 python tests/test_ring_buffer.py
|
||||||
```
|
```
|
||||||
|
|
||||||
## Benchmarks
|
## Benchmarks
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python bench.py
|
# Standard GPU benchmark
|
||||||
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python bench_offload.py
|
python bench.py
|
||||||
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python bench_vllm.py
|
|
||||||
|
# CPU offload benchmark
|
||||||
|
python bench_offload.py
|
||||||
|
|
||||||
|
# vLLM comparison benchmark
|
||||||
|
python bench_vllm.py
|
||||||
```
|
```
|
||||||
|
|
||||||
## Quick Verification
|
## Quick Verification
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Import test
|
# Import test
|
||||||
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python -c "from nanovllm import LLM"
|
python -c "from nanovllm import LLM"
|
||||||
|
|
||||||
|
# Run offload benchmark (tests CPU-primary ring buffer mode)
|
||||||
|
python bench_offload.py
|
||||||
```
|
```
|
||||||
|
|||||||
2
.gitmodules
vendored
2
.gitmodules
vendored
@@ -1,4 +1,4 @@
|
|||||||
[submodule "3rdparty/Block-SparseAttention"]
|
[submodule "3rdparty/Block-SparseAttention"]
|
||||||
path = 3rdparty/Block-SparseAttention
|
path = 3rdparty/Block-SparseAttention
|
||||||
url = https://github.com/Zijie-Tian/Block-SparseAttention.git
|
url = https://github.com/Zijie-Tian/Block-Sparse-Attention.git
|
||||||
branch = tzj/minference
|
branch = tzj/minference
|
||||||
|
|||||||
519
CLAUDE.md
519
CLAUDE.md
@@ -4,79 +4,444 @@ This file provides guidance to Claude Code when working with this repository.
|
|||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Supports multiple model architectures (Qwen3, Qwen2, Llama) with CPU offload for long-context inference.
|
Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline LLM inference. Supports Qwen3 models with CPU offload for long-context inference.
|
||||||
|
|
||||||
## GPU Mutex for Multi-Instance Debugging
|
## GPU Mutex for Multi-Instance Debugging
|
||||||
|
|
||||||
**IMPORTANT**: When running multiple Claude instances for parallel debugging, different rules apply based on script type:
|
**IMPORTANT**: When running multiple Claude instances for parallel debugging, only one GPU (cuda:0) is available. Before executing ANY command that uses the GPU (python scripts, benchmarks, tests), Claude MUST:
|
||||||
|
|
||||||
### Benchmarks (`bench*.py`) - Exclusive GPU Access Required
|
1. **Check GPU availability** by running:
|
||||||
|
```bash
|
||||||
|
nvidia-smi --query-compute-apps=pid,name,used_memory --format=csv,noheader
|
||||||
|
```
|
||||||
|
|
||||||
Before running any `bench*.py` script, Claude MUST wait for exclusive GPU access:
|
2. **If processes are running on GPU**:
|
||||||
|
- Wait and retry every 10 seconds until GPU is free
|
||||||
|
- Use this polling loop:
|
||||||
|
```bash
|
||||||
|
while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do
|
||||||
|
echo "GPU busy, waiting 10s..."
|
||||||
|
sleep 10
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Only proceed** when `nvidia-smi --query-compute-apps=pid --format=csv,noheader` returns empty output
|
||||||
|
|
||||||
|
**Example workflow**:
|
||||||
```bash
|
```bash
|
||||||
# Check and wait for GPU to be free
|
# First check if GPU is in use
|
||||||
while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do
|
nvidia-smi --query-compute-apps=pid,name,used_memory --format=csv,noheader
|
||||||
echo "GPU busy, waiting 10s..."
|
|
||||||
sleep 10
|
# If output is empty, proceed with your command
|
||||||
done
|
python bench_offload.py
|
||||||
|
|
||||||
|
# If output shows processes, wait until they finish
|
||||||
```
|
```
|
||||||
|
|
||||||
### Other Scripts (tests, examples) - No Special Requirements
|
**Note**: This applies to ALL GPU operations including:
|
||||||
|
- Running tests (`python tests/test_*.py`)
|
||||||
|
- Running benchmarks (`python bench*.py`)
|
||||||
|
- Running examples (`python example.py`)
|
||||||
|
- Any script that imports torch/cuda
|
||||||
|
|
||||||
For non-benchmark scripts, exclusive GPU access is NOT required. Multiple nanovllm processes can run simultaneously on different GPUs - each process automatically selects a unique port for `torch.distributed` communication.
|
## Local Package Installation for Multi-Instance
|
||||||
|
|
||||||
## Multi-Instance Development with PYTHONPATH
|
**CRITICAL**: After ANY code modification in the `nanovllm/` directory, you MUST reinstall the package before running tests or benchmarks:
|
||||||
|
|
||||||
**IMPORTANT**: When running multiple Claude instances on different worktrees, do NOT use `pip install -e .` globally as it will affect other instances.
|
|
||||||
|
|
||||||
**Use PYTHONPATH directly** - no pip install needed:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Set PYTHONPATH to point to the project root directory
|
pip install -e . --prefix=./.local --no-deps
|
||||||
PYTHONPATH=/path/to/your/worktree:$PYTHONPATH python <script.py>
|
|
||||||
|
|
||||||
# Example: running tests
|
|
||||||
PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
|
|
||||||
```
|
```
|
||||||
|
|
||||||
**Benefits**:
|
Then run with PYTHONPATH:
|
||||||
- No `pip install` required
|
```bash
|
||||||
- Code changes take effect immediately (no reinstall needed)
|
PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python <script.py>
|
||||||
- Each worktree is completely isolated
|
```
|
||||||
|
|
||||||
## Documentation Index
|
**IMPORTANT**: When running multiple Claude instances on different worktrees, do NOT use `pip install -e .` globally as it will affect other instances. Instead, use local installation:
|
||||||
|
|
||||||
| Document | Purpose |
|
1. **Install to worktree-local directory**:
|
||||||
|----------|---------|
|
```bash
|
||||||
| [`docs/architecture_guide.md`](docs/architecture_guide.md) | Core components, layer-wise CPU offload design, prefill/decode flows, implementation details |
|
pip install -e . --prefix=./.local --no-deps
|
||||||
| [`docs/multi_model_support.md`](docs/multi_model_support.md) | Model registry system, adding new models (Qwen3/Llama), architecture differences, RoPE scaling |
|
```
|
||||||
| [`docs/cuda_graph_offload_guide.md`](docs/cuda_graph_offload_guide.md) | CUDA graph support for CPU offload decode path, 4x decode speedup |
|
|
||||||
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (MInference, FlexPrefill, XAttention, Quest), computation flow |
|
2. **Set PYTHONPATH before running any Python command**:
|
||||||
| [`docs/block_sparse_attention_lib.md`](docs/block_sparse_attention_lib.md) | MIT-Han-Lab Block-Sparse-Attention library reference: sparse modes, API, performance |
|
```bash
|
||||||
| [`docs/sparse_prefill_integration_plan.md`](docs/sparse_prefill_integration_plan.md) | Integration plan for MInference/XAttention/FlexPrefill with unified BlockMask interface |
|
export PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH
|
||||||
| [`docs/sparse_offload_integration.md`](docs/sparse_offload_integration.md) | Sparse policy integration with layerwise offload, `requires_block_selection` interface design |
|
```
|
||||||
| [`docs/layerwise_offload_memory_analysis.md`](docs/layerwise_offload_memory_analysis.md) | Memory allocation analysis with theoretical formulas and empirical validation (< 5% error) |
|
|
||||||
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, tensor comparison, memory profiling |
|
3. **Combined example**:
|
||||||
| [`docs/gpu_only_performance_issue.md`](docs/gpu_only_performance_issue.md) | GPU-only mode slower than offload due to PagedAttention scatter overhead, optimization proposals |
|
```bash
|
||||||
| [`docs/offload_accuracy_issue.md`](docs/offload_accuracy_issue.md) | **BUG**: CPU offload mode 66% accuracy vs 100% non-offload on RULER NIAH benchmark |
|
# One-liner for running tests with local package
|
||||||
| [`docs/64k_memory_analysis.md`](docs/64k_memory_analysis.md) | 64k inference memory analysis: GPU-only vs offload, OOM root cause (fragmentation), RTX 3090 limitations |
|
PYTHONPATH=./.local/lib/python3.10/site-packages:$PYTHONPATH python tests/test_needle.py
|
||||||
| [`docs/xattention_integration.md`](docs/xattention_integration.md) | XAttention integration guide: algorithm, implementation, design decisions, and testing |
|
```
|
||||||
| [`docs/xattention_analysis.md`](docs/xattention_analysis.md) | XAttention algorithm analysis: chunked estimation, block sparse attention, integration design |
|
|
||||||
| [`docs/development_notes.md`](docs/development_notes.md) | Development notes and scratchpad for ongoing work |
|
**Note**: The Python version in the path (python3.10) should match your environment.
|
||||||
| [`docs/chunked_prefill_analysis.md`](docs/chunked_prefill_analysis.md) | **NEW**: Chunked prefill for ultra-long sequences (1M+), memory analysis, MLP activation breakdown, implementation guide |
|
|
||||||
|
**CRITICAL**: After making code changes to `nanovllm/` source files, you MUST reinstall the package for changes to take effect:
|
||||||
|
```bash
|
||||||
|
pip install -e . --prefix=./.local --no-deps
|
||||||
|
```
|
||||||
|
Without reinstallation, Python will use the old cached version and your changes will NOT be reflected!
|
||||||
|
|
||||||
|
## Sparse Attention
|
||||||
|
|
||||||
|
For sparse attention related content (block sparse attention, MInference, FlexPrefill, XAttention, AvgPool, etc.), refer to [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md).
|
||||||
|
|
||||||
|
### Quest Sparse Policy
|
||||||
|
|
||||||
|
**Files**: `nanovllm/kvcache/sparse/quest.py`, `nanovllm/kvcache/sparse/policy.py`
|
||||||
|
|
||||||
|
Quest policy selects Top-K blocks based on query-key similarity bounds using min/max key metadata.
|
||||||
|
|
||||||
|
**Scoring Mechanism**:
|
||||||
|
```python
|
||||||
|
score_min = torch.einsum('hd,bhd->bh', q, key_min) # [num_blocks, kv_heads]
|
||||||
|
score_max = torch.einsum('hd,bhd->bh', q, key_max) # [num_blocks, kv_heads]
|
||||||
|
scores = torch.maximum(score_min, score_max).mean(dim=-1) # [num_blocks] ← averaged!
|
||||||
|
```
|
||||||
|
|
||||||
|
**Critical Limitation - No Per-Head Scheduling**:
|
||||||
|
|
||||||
|
The `.mean(dim=-1)` averages scores across all heads, making a **unified** block selection for all heads:
|
||||||
|
|
||||||
|
```
|
||||||
|
Block A: head0 needs (+4), head1 doesn't (-4) → avg = 0 → NOT selected
|
||||||
|
Block B: head0 doesn't (-4), head1 needs (+4) → avg = 0 → NOT selected
|
||||||
|
Block C: both heads moderately need (+2, +2) → avg = +2 → selected
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why Per-Head Scheduling is Infeasible**:
|
||||||
|
1. **Memory Layout**: GPU cache stores all heads together `[block_size, kv_heads, head_dim]`
|
||||||
|
2. **FlashAttention**: Requires complete heads - partial heads cause dimension mismatch
|
||||||
|
3. **Block Granularity**: If any head needs a block, the entire block (all heads) must be loaded
|
||||||
|
|
||||||
|
**Policy Types**:
|
||||||
|
- `FullAttentionPolicy`: `supports_prefill=True, supports_decode=True` - loads all blocks
|
||||||
|
- `QuestPolicy`: `supports_prefill=False, supports_decode=True` - decode-only Top-K selection
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
### Core Components
|
||||||
|
|
||||||
|
- **LLMEngine** (`llm_engine.py`): Main entry, runs prefill-decode loop
|
||||||
|
- **ModelRunner** (`model_runner.py`): Loads weights, allocates KV cache, CUDA graphs
|
||||||
|
- **Scheduler** (`scheduler.py`): Two-phase scheduling (prefill → decode)
|
||||||
|
- **BlockManager** (`block_manager.py`): Paged attention with prefix caching (xxhash), default block size 4096
|
||||||
|
- **Attention** (`layers/attention.py`): FlashAttention with chunked methods for CPU offload
|
||||||
|
|
||||||
|
## PyTorch Hooks for Debugging
|
||||||
|
|
||||||
|
### Hook Positions in Qwen3
|
||||||
|
|
||||||
|
```
|
||||||
|
decoder_layer
|
||||||
|
├── input_layernorm (RMSNorm)
|
||||||
|
├── self_attn (Qwen3Attention) ← Hook here for attention I/O after o_proj
|
||||||
|
│ ├── q_proj → q_norm → RoPE
|
||||||
|
│ ├── k_proj → k_norm → RoPE
|
||||||
|
│ ├── v_proj
|
||||||
|
│ ├── attn (Attention) ← Hook here for Q/K/V tensors
|
||||||
|
│ │ └── FlashAttention / SDPA
|
||||||
|
│ └── o_proj
|
||||||
|
├── post_attention_layernorm (RMSNorm)
|
||||||
|
└── mlp (Qwen3MLP)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Hook Types & Data Shapes
|
||||||
|
|
||||||
|
| Hook Position | Type | Captured Data |
|
||||||
|
|---------------|------|---------------|
|
||||||
|
| `self_attn` | post | `[batch, seq_len, hidden_size]` - after o_proj |
|
||||||
|
| `self_attn.attn` | pre | Q,K,V: `[seq_len, num_heads, head_dim]` - after RoPE |
|
||||||
|
| `self_attn.attn` | post | `[seq_len, num_heads, head_dim]` - before o_proj |
|
||||||
|
|
||||||
|
### Example: Capture Attention Outputs
|
||||||
|
|
||||||
|
```python
|
||||||
|
storage = {}
|
||||||
|
|
||||||
|
def make_hook(layer_id: int, storage: dict):
|
||||||
|
def hook(module, inputs, output):
|
||||||
|
if isinstance(output, tuple):
|
||||||
|
attn_output = output[0]
|
||||||
|
else:
|
||||||
|
attn_output = output
|
||||||
|
# nanovllm shape: [num_tokens, hidden_size] -> add batch dim
|
||||||
|
if attn_output.dim() == 2:
|
||||||
|
attn_output = attn_output.unsqueeze(0)
|
||||||
|
storage[layer_id] = attn_output.detach().clone()
|
||||||
|
return hook
|
||||||
|
|
||||||
|
# Register hooks
|
||||||
|
hooks = []
|
||||||
|
for layer_idx, layer in enumerate(model.model.layers):
|
||||||
|
hooks.append(layer.self_attn.register_forward_hook(make_hook(layer_idx, storage)))
|
||||||
|
|
||||||
|
# Run inference...
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
for hook in hooks:
|
||||||
|
hook.remove()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Reference Implementation
|
||||||
|
|
||||||
|
Key files:
|
||||||
|
- `tests/modeling_qwen3.py`: Reference Qwen3 implementation (torch + transformers only)
|
||||||
|
- `tests/test_needle_ref.py`: Reference needle test using custom Qwen3
|
||||||
|
- `tests/test_needle.py`: Needle-in-haystack test for nanovllm
|
||||||
|
|
||||||
|
### Common Pitfalls
|
||||||
|
|
||||||
|
1. **Shape mismatch**: nanovllm uses `[num_tokens, ...]` while torch uses `[batch, seq_len, ...]`
|
||||||
|
2. **Hook position**: `self_attn` captures after o_proj, `self_attn.attn` captures before o_proj
|
||||||
|
3. **Output format**: nanovllm returns tuple `(attn_output, None)`, handle with `output[0]`
|
||||||
|
|
||||||
|
## CPU Offload System
|
||||||
|
|
||||||
|
### Ring Buffer Design
|
||||||
|
|
||||||
|
```
|
||||||
|
GPU Slots: [0] [1] [2] [3] ... (unified ring buffer)
|
||||||
|
Prefill: slot = chunk_idx % N
|
||||||
|
Decode: slot[0] = decode, slots[1:] = load previous chunks
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Files**: `kvcache/offload_engine.py`, `kvcache/hybrid_manager.py`
|
||||||
|
|
||||||
|
**Memory Layout**:
|
||||||
|
- GPU: `[num_layers, num_gpu_blocks, block_size, kv_heads, head_dim]`
|
||||||
|
- CPU: `[num_layers, num_cpu_blocks, ...]` (pinned memory)
|
||||||
|
|
||||||
|
**Key Methods**:
|
||||||
|
- `load_to_slot_layer(slot, layer, cpu_block)`: Async H2D load
|
||||||
|
- `offload_slot_to_cpu(slot, cpu_block)`: Async D2H offload
|
||||||
|
- Per-slot per-layer CUDA events for fine-grained synchronization
|
||||||
|
|
||||||
|
**Pipeline**: N-way pipeline with dedicated streams for full compute-transfer overlap. Pipeline depth = N-1 (prefill), (N-1)/2 (decode).
|
||||||
|
|
||||||
|
### Stream Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
Transfer Streams: [slot_0_stream] [slot_1_stream] ... [slot_N_stream]
|
||||||
|
↓ ↓ ↓
|
||||||
|
GPU Slots: [slot_0] [slot_1] ... [slot_N]
|
||||||
|
↓ ↓ ↓
|
||||||
|
Compute Stream: ←←←←←←←←←←←← [dedicated compute stream] →→→→→→→→→→→→
|
||||||
|
```
|
||||||
|
|
||||||
|
**Key Design Decisions**:
|
||||||
|
- **Per-slot transfer streams**: Each GPU slot has its own CUDA stream for H2D transfers, enabling parallel loading
|
||||||
|
- **Dedicated compute stream**: Created with `torch.cuda.Stream()` (NOT `current_stream()`) to avoid implicit synchronization with default stream
|
||||||
|
- **CUDA Events**: `ring_slot_ready` (transfer complete), `ring_slot_compute_done` (safe to overwrite)
|
||||||
|
|
||||||
|
## Scatter-Gather DMA (sgDMA) - INTEGRATED ✓
|
||||||
|
|
||||||
|
### Problem & Solution
|
||||||
|
|
||||||
|
**Problem**: Strided CPU cache access `k_cache_cpu[:, block_id]` caused slow Device→Pageable transfers at ~1.4 GB/s instead of optimal ~24 GB/s pinned memory bandwidth.
|
||||||
|
|
||||||
|
**Solution**: Implemented `cudaMemcpy2D` via custom CUDA extension to handle strided layouts natively. **Integration complete** as of 2025-12-25.
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
```python
|
||||||
|
from nanovllm.comm import memcpy_2d_async
|
||||||
|
|
||||||
|
# Transfer block_id across all layers
|
||||||
|
spitch = num_blocks * features * dtype_size # stride between layers
|
||||||
|
dpitch = features * dtype_size # contiguous destination
|
||||||
|
width = features * dtype_size # bytes per row
|
||||||
|
height = num_layers # number of rows
|
||||||
|
|
||||||
|
memcpy_2d_async(gpu_buf, cpu_cache[:, block_id], dpitch, spitch, width, height, "h2d", stream)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Benchmark Performance (Synthetic, 256MB)
|
||||||
|
|
||||||
|
| Method | Bandwidth | Speedup |
|
||||||
|
|--------|-----------|---------|
|
||||||
|
| **cudaMemcpy2D (sgDMA)** | **24.95 GB/s** | **Baseline** |
|
||||||
|
| PyTorch strided | 4.25 GB/s | **5.87x slower** |
|
||||||
|
| PyTorch contiguous | 24.92 GB/s | Same |
|
||||||
|
|
||||||
|
### Real-World Performance (A100, Attention Offload)
|
||||||
|
|
||||||
|
**Measured from `test_attention_offload.py` profiling**:
|
||||||
|
|
||||||
|
| Transfer Type | Count | Bandwidth | Previous | Speedup |
|
||||||
|
|---------------|-------|-----------|----------|---------|
|
||||||
|
| **Device→Pinned (D2H)** | 416 | **21.49 GB/s** | 1.40 GB/s | **15.35x** |
|
||||||
|
| **Pinned→Device (H2D)** | 24,960 | **23.39 GB/s** | N/A | N/A |
|
||||||
|
| Device→Pageable (D2H) | **0** | N/A | ~40 transfers | **Eliminated** |
|
||||||
|
|
||||||
|
**Verification**: All slow Device→Pageable transfers eliminated. System achieves near-optimal PCIe Gen3 x16 bandwidth.
|
||||||
|
|
||||||
|
**Build**: `python setup.py build_ext --inplace`
|
||||||
|
|
||||||
|
**Files**:
|
||||||
|
- `csrc/sgdma_kernel.cu`, `csrc/sgdma.cpp`: CUDA extension
|
||||||
|
- `nanovllm/comm/sgdma.py`: Python API
|
||||||
|
- `kvcache/offload_engine.py`: Integration (4 methods updated)
|
||||||
|
|
||||||
|
### Integration Details
|
||||||
|
|
||||||
|
**Modified methods in `offload_engine.py`**:
|
||||||
|
- `load_to_slot_all_layers()`: H2D ring buffer load
|
||||||
|
- `offload_slot_to_cpu()`: D2H ring buffer offload
|
||||||
|
- `offload_decode_slot()`: D2H decode slot offload
|
||||||
|
- `load_cpu_blocks_to_gpu_slots_all_layers()`: Batch H2D load
|
||||||
|
|
||||||
|
**Example replacement**:
|
||||||
|
```python
|
||||||
|
# Before (slow, Device→Pageable fallback)
|
||||||
|
self.k_cache_gpu[:, slot].copy_(self.k_cache_cpu[:, cpu_block], non_blocking=True)
|
||||||
|
|
||||||
|
# After (fast, Device→Pinned via sgDMA)
|
||||||
|
memcpy_2d_async(
|
||||||
|
self.k_cache_gpu[:, slot], self.k_cache_cpu[:, cpu_block],
|
||||||
|
self.gpu_pitch, self.cpu_pitch, self.width, self.height,
|
||||||
|
"h2d", stream=self.transfer_stream_main
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Actual Impact**: 15.35x faster D2H transfers, eliminates memory transfer bottleneck. Expected 2-3x overall prefill throughput improvement.
|
||||||
|
|
||||||
|
## Online Softmax Merge - Triton Fused Kernel ✓
|
||||||
|
|
||||||
|
### Problem & Solution
|
||||||
|
|
||||||
|
**Problem**: Original PyTorch implementation of `merge_attention_outputs()` launches 7 separate kernels per merge operation:
|
||||||
|
1. `torch.maximum()` - max(lse1, lse2)
|
||||||
|
2. `torch.exp()` (2x) - exp(lse1-max), exp(lse2-max)
|
||||||
|
3. `transpose()` + `unsqueeze()` - reshape for broadcasting
|
||||||
|
4. Accumulation (6x) - weighted sum operations
|
||||||
|
5. Division - normalize output
|
||||||
|
6. `torch.log()` - merge LSE
|
||||||
|
7. `.to()` - type conversion
|
||||||
|
|
||||||
|
**Profiling revealed**: In ChunkedPrefill with 8 layers, these operations consumed **698 ms** GPU time (vs FlashAttention 603 ms), becoming a major bottleneck.
|
||||||
|
|
||||||
|
**Solution**: Implemented Triton fused kernels that combine all operations into 2 kernels. **Integration complete** as of 2025-12-25.
|
||||||
|
|
||||||
|
### Implementation
|
||||||
|
|
||||||
|
**File**: `nanovllm/kvcache/chunked_attention.py:278-408`
|
||||||
|
|
||||||
|
Two Triton kernels replace all PyTorch operations:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@triton.jit
|
||||||
|
def _merge_lse_kernel(...):
|
||||||
|
"""Fused: max + exp + log"""
|
||||||
|
max_lse = tl.maximum(lse1, lse2)
|
||||||
|
exp1 = tl.exp(lse1 - max_lse)
|
||||||
|
exp2 = tl.exp(lse2 - max_lse)
|
||||||
|
lse_merged = max_lse + tl.log(exp1 + exp2)
|
||||||
|
tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
|
||||||
|
|
||||||
|
@triton.jit
|
||||||
|
def _merge_output_kernel(...):
|
||||||
|
"""Fused: broadcast + weighted sum + division"""
|
||||||
|
# Load LSE, compute scaling factors
|
||||||
|
exp1 = tl.exp(lse1 - max_lse)
|
||||||
|
exp2 = tl.exp(lse2 - max_lse)
|
||||||
|
sum_exp = exp1 + exp2
|
||||||
|
|
||||||
|
# Process headdim in chunks
|
||||||
|
for d_offset in range(0, headdim, BLOCK_SIZE):
|
||||||
|
o1_val = tl.load(o1_ptr + o_idx, mask=mask)
|
||||||
|
o2_val = tl.load(o2_ptr + o_idx, mask=mask)
|
||||||
|
o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
|
||||||
|
tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Performance Results
|
||||||
|
|
||||||
|
**From `test_attention_offload.py` profiling** (8 layers, 16K tokens, 16 chunks, 10 iterations):
|
||||||
|
|
||||||
|
| Metric | PyTorch (7 kernels) | Triton (2 kernels) | Speedup |
|
||||||
|
|--------|---------------------|---------------------|---------|
|
||||||
|
| **GPU time (8 layers)** | 698 ms | 160.7 ms | **4.3x** |
|
||||||
|
| **Per-layer time** | 87.3 ms | 20.1 ms | **4.3x** |
|
||||||
|
| **Avg per merge** | 56 µs | 12.9 µs | **4.3x** |
|
||||||
|
| **Kernel launches** | 10,920 | 3,120 | **71% reduction** |
|
||||||
|
|
||||||
|
**Breakdown** (per-layer, 1,560 merges):
|
||||||
|
- `_merge_output_kernel`: 126.9 ms / 8 = 15.9 ms/layer (avg 10.2 µs/call)
|
||||||
|
- `_merge_lse_kernel`: 33.8 ms / 8 = 4.2 ms/layer (avg 2.7 µs/call)
|
||||||
|
|
||||||
|
### Overall ChunkedPrefill Impact
|
||||||
|
|
||||||
|
**GPU time distribution** (test_attention_offload.py):
|
||||||
|
|
||||||
|
| Component | Time (ms) | Percentage |
|
||||||
|
|-----------|-----------|------------|
|
||||||
|
| FlashAttention | 603.2 | 74.8% |
|
||||||
|
| Triton Merge | 160.7 | 19.9% |
|
||||||
|
| Other | 42.1 | 5.3% |
|
||||||
|
| **Total** | **806.0** | **100%** |
|
||||||
|
|
||||||
|
**If using PyTorch merge** (estimated):
|
||||||
|
- Total GPU time: ~1,343 ms
|
||||||
|
- **Overall speedup with Triton**: 1.67x
|
||||||
|
|
||||||
|
### Key Files
|
||||||
|
|
||||||
|
- `nanovllm/kvcache/chunked_attention.py`: Triton kernels + merge function
|
||||||
|
|
||||||
|
## Known Issues and Fixes
|
||||||
|
|
||||||
|
### Partial Last Block Bug (FIXED ✓)
|
||||||
|
|
||||||
|
**Problem**: When prefill token count is not an exact multiple of `block_size`, decode outputs garbage.
|
||||||
|
|
||||||
|
**Root Cause**: `_chunked_decode_attention` calculated `last_block_valid_tokens` using `len(seq) - 1`, which increases during decode. But CPU blocks are fixed after prefill!
|
||||||
|
|
||||||
|
```python
|
||||||
|
# BUG: len(seq) increases each decode step
|
||||||
|
total_prefill_tokens = len(seq) - 1 # Wrong!
|
||||||
|
last_block_valid_tokens = total_prefill_tokens % block_size # Reads garbage from CPU
|
||||||
|
```
|
||||||
|
|
||||||
|
**Fix**: Cache original prefill length in `HybridKVCacheManager.get_prefill_len()`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# CORRECT: Use cached prefill length
|
||||||
|
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Fixed value
|
||||||
|
```
|
||||||
|
|
||||||
|
**Files Modified**:
|
||||||
|
- `nanovllm/kvcache/hybrid_manager.py`: Added `_prefill_len` dict and `get_prefill_len()` method
|
||||||
|
- `nanovllm/layers/attention.py`: Use `get_prefill_len()` instead of `len(seq) - 1`
|
||||||
|
|
||||||
|
### Block Size 4096 Race Condition (FIXED ✓)
|
||||||
|
|
||||||
|
**Problem**: `block_size=4096` with multiple chunks produced `index_copy_(): index out of bounds` CUDA error during Chunk 2 processing.
|
||||||
|
|
||||||
|
**Root Cause**: Race condition between default stream and compute stream. In `_prepare_chunked_offload_chunk()`, `slot_mapping` tensor was created with `non_blocking=True` H2D transfer on the default stream. However, `store_kvcache` runs on `compute_stream`. Without synchronization, `compute_stream` could use `slot_mapping` before its transfer completed, causing corrupted indices.
|
||||||
|
|
||||||
|
**Fix** (in `attention.py`):
|
||||||
|
```python
|
||||||
|
if is_chunked_offload:
|
||||||
|
compute_stream = context.kvcache_manager.offload_engine.compute_stream
|
||||||
|
if k_cache.numel() and v_cache.numel():
|
||||||
|
# CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
|
||||||
|
compute_stream.wait_stream(torch.cuda.default_stream())
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Tested block sizes**: 512, 1024, 4096, 8192 - all pass.
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
| Parameter | Default | Notes |
|
| Parameter | Default | Notes |
|
||||||
|-----------|---------|-------|
|
|-----------|---------|-------|
|
||||||
| `kvcache_block_size` | 4096 | Tokens per block |
|
| `kvcache_block_size` | 1024 | Tokens per block (4096 now works after race condition fix) |
|
||||||
| `max_num_batched_tokens` | 16384 | Set = max_model_len for long context |
|
| `max_num_batched_tokens` | 16384 | Set = max_model_len for long context |
|
||||||
| `gpu_memory_utilization` | 0.9 | GPU memory fraction |
|
| `gpu_memory_utilization` | 0.9 | GPU memory fraction |
|
||||||
| `enable_cpu_offload` | False | Enable for long context |
|
| `enable_cpu_offload` | False | Enable for long context |
|
||||||
| `num_gpu_blocks` | 2 | GPU blocks for offload mode |
|
|
||||||
| `num_kv_buffers` | 4 | Ring buffer size (1-4), lower = less memory but slower decode |
|
|
||||||
| `enforce_eager` | False | Set True to disable CUDA graphs |
|
|
||||||
|
|
||||||
## Benchmarking
|
## Benchmarking
|
||||||
|
|
||||||
@@ -90,14 +455,58 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
|
|||||||
**Model Limits**:
|
**Model Limits**:
|
||||||
- Qwen3-0.6B/4B: 40960 tokens
|
- Qwen3-0.6B/4B: 40960 tokens
|
||||||
- Qwen2.5-7B-Instruct-1M: 1048576 tokens
|
- Qwen2.5-7B-Instruct-1M: 1048576 tokens
|
||||||
- Llama-3.1-8B-Instruct: 131072 tokens
|
|
||||||
- **64k on RTX 3090/4090 (24GB)**: Requires CPU offload + optimizations, see [`docs/64k_memory_analysis.md`](docs/64k_memory_analysis.md)
|
|
||||||
|
|
||||||
**Performance (Qwen3-4B, CPU Offload)**:
|
**Performance (Qwen3-0.6B)**:
|
||||||
- Prefill: ~5700-8000 tok/s (varies by context length)
|
- GPU: ~18k tok/s (prefill), ~100 tok/s (decode)
|
||||||
- Decode with CUDA Graph: ~50 tok/s (TPOT ~19ms)
|
- CPU Offload (16K): ~14k tok/s (prefill)
|
||||||
- Decode Eager Mode: ~12 tok/s (TPOT ~80ms)
|
- CPU Offload (32K): ~13k tok/s (prefill)
|
||||||
- **CUDA Graph speedup: 4x decode throughput**
|
|
||||||
|
## Performance Summary
|
||||||
|
|
||||||
|
### Completed Optimizations ✓
|
||||||
|
|
||||||
|
1. **sgDMA Integration** (2025-12-25)
|
||||||
|
- Eliminated Device→Pageable transfers
|
||||||
|
- Achieved 21-23 GB/s bandwidth (near PCIe limit)
|
||||||
|
- 15.35x speedup on memory transfers
|
||||||
|
|
||||||
|
2. **Triton Fused Merge Kernel** (2025-12-25)
|
||||||
|
- Reduced 7 PyTorch kernels → 2 Triton kernels
|
||||||
|
- 4.3x speedup on merge operations
|
||||||
|
- 1.67x overall ChunkedPrefill speedup
|
||||||
|
|
||||||
|
3. **N-way Pipeline with Dedicated Streams** (2025-12-25)
|
||||||
|
- Per-slot transfer streams for parallel H2D across slots
|
||||||
|
- Dedicated compute stream (avoids CUDA default stream implicit sync)
|
||||||
|
- N-way pipeline using all available slots (not just 2-slot double buffering)
|
||||||
|
- **2.0x improvement**: 7.2k → 14.1k tok/s (16K tokens prefill)
|
||||||
|
|
||||||
|
### Current Performance Bottlenecks
|
||||||
|
|
||||||
|
**From profiling** (`test_attention_offload.py`, 8 layers, 16K tokens):
|
||||||
|
|
||||||
|
| Component | GPU Time | Percentage | Optimization Potential |
|
||||||
|
|-----------|----------|------------|------------------------|
|
||||||
|
| FlashAttention | 603 ms | 74.8% | ⚠️ Main bottleneck |
|
||||||
|
| Triton Merge | 161 ms | 19.9% | ✓ Optimized |
|
||||||
|
| Other | 42 ms | 5.3% | Minor |
|
||||||
|
|
||||||
|
### Future Optimization Directions
|
||||||
|
|
||||||
|
1. **FlashAttention Optimization** (highest priority)
|
||||||
|
- Current: 74.8% of GPU time
|
||||||
|
- Potential: Custom FlashAttention kernel for chunked case
|
||||||
|
- Expected: 1.5-2x additional speedup
|
||||||
|
|
||||||
|
2. ~~**Pipeline Optimization**~~ ✓ COMPLETED
|
||||||
|
- ~~Better overlap between compute and memory transfer~~
|
||||||
|
- ~~Multi-stream execution~~
|
||||||
|
- See: N-way Pipeline with Dedicated Streams above
|
||||||
|
|
||||||
|
3. **Alternative to sgDMA** (lower priority, PyTorch-only)
|
||||||
|
- Reorganize cache layout: `[num_cpu_blocks, num_layers, ...]` instead of `[num_layers, num_cpu_blocks, ...]`
|
||||||
|
- Trade-off: Extensive refactoring vs minimal sgDMA approach
|
||||||
|
- Same performance as sgDMA (~24 GB/s)
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|||||||
103
DEBUG_SUMMARY.md
Normal file
103
DEBUG_SUMMARY.md
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
# Chunked Prefill Bug Debug Summary
|
||||||
|
|
||||||
|
## Problem
|
||||||
|
`test_needle.py --enable-offload --input-len 8192` fails with garbage output.
|
||||||
|
|
||||||
|
The model generates completely wrong tokens instead of the expected "7492".
|
||||||
|
|
||||||
|
## Investigation Progress
|
||||||
|
|
||||||
|
### 1. Stream Synchronization Fix (Completed)
|
||||||
|
- Replaced Triton `store_kvcache` kernel with pure PyTorch operations
|
||||||
|
- Moved `store_kvcache` to `compute_stream` in chunked prefill mode
|
||||||
|
- Added sync: `compute_stream.wait_event(offload_done)` after per-layer offload
|
||||||
|
- Added sync: `default_stream.wait_stream(compute_stream)` before return
|
||||||
|
|
||||||
|
### 2. KV Cache Alignment Verification (Completed)
|
||||||
|
Created alignment tests to compare K/V tensors between torch reference and nanovllm:
|
||||||
|
|
||||||
|
**RoPE Alignment:**
|
||||||
|
- RoPE implementations match perfectly (max_diff=0.002, cosine ~1.0)
|
||||||
|
- Confirmed RoPE is NOT the cause of the bug
|
||||||
|
|
||||||
|
**K/V Cache Alignment (Chunk 0):**
|
||||||
|
- Cosine similarity: ~1.0 for all layers
|
||||||
|
- Max diff: 2-7 (grows linearly with position, characteristic of FP16 precision)
|
||||||
|
- Mean diff: < 0.001
|
||||||
|
- **Conclusion: K/V cache offload is working correctly**
|
||||||
|
|
||||||
|
### 3. Layer Output Divergence Analysis (Completed)
|
||||||
|
Created per-chunk layer output comparison:
|
||||||
|
|
||||||
|
**Chunk 0 (tokens 0-4096):**
|
||||||
|
- All layers pass with excellent cosine similarity (0.999+)
|
||||||
|
- Max diff grows in later layers but within acceptable range
|
||||||
|
|
||||||
|
**Chunk 1 (tokens 4096-8192):**
|
||||||
|
- Layers 0-19: OK (cosine ~1.0)
|
||||||
|
- Layers 20-27: Diverge (cosine 0.83-0.96, max_diff up to 114)
|
||||||
|
- Divergence correlates with later transformer layers
|
||||||
|
|
||||||
|
### 4. Critical Discovery: Single-Chunk Offload Also Fails
|
||||||
|
**Key finding:** Even with input_len=2048 (single chunk, no chunked attention), the model produces garbage output with CPU offload enabled.
|
||||||
|
|
||||||
|
```
|
||||||
|
# Without offload: PASSES
|
||||||
|
python tests/test_needle.py --input-len 2048
|
||||||
|
# Output: "7492" (correct)
|
||||||
|
|
||||||
|
# With offload: FAILS
|
||||||
|
python tests/test_needle.py --enable-offload --input-len 2048
|
||||||
|
# Output: "The Ble White Th G Lopsiswin..." (garbage)
|
||||||
|
```
|
||||||
|
|
||||||
|
**This proves the bug is NOT in:**
|
||||||
|
- Chunked attention logic (merge_attention_outputs)
|
||||||
|
- Multi-chunk KV loading
|
||||||
|
- Ring buffer pipeline
|
||||||
|
|
||||||
|
**The bug IS in:**
|
||||||
|
- The decode path when CPU offload is enabled
|
||||||
|
- How prefilled KV is loaded/used during decode
|
||||||
|
|
||||||
|
### 5. Decode Path Analysis (In Progress)
|
||||||
|
The decode path in CPU offload mode:
|
||||||
|
1. Prefill writes KV to GPU, offloads to CPU
|
||||||
|
2. Decode loads prefilled KV from CPU via `_decode_ring_buffer_pipeline`
|
||||||
|
3. Attend to prefilled KV + accumulated decode tokens
|
||||||
|
4. Merge results
|
||||||
|
|
||||||
|
**Observations:**
|
||||||
|
- `prefilled_blocks` set is empty after decode (should contain block IDs)
|
||||||
|
- CPU cache has valid data (reasonable mean/std values)
|
||||||
|
- Decode buffer has zeros (decode tokens not being stored correctly?)
|
||||||
|
|
||||||
|
## Current Status
|
||||||
|
|
||||||
|
### Working
|
||||||
|
- Stream synchronization fixes
|
||||||
|
- K/V cache offload to CPU (verified alignment)
|
||||||
|
- RoPE implementation
|
||||||
|
- Chunked prefill attention for first chunk
|
||||||
|
|
||||||
|
### Not Working
|
||||||
|
- Decode with CPU offload (even for single-chunk inputs)
|
||||||
|
- Multi-chunk attention (divergence in later layers for chunk 1)
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
1. Debug why `prefilled_blocks` is empty after decode
|
||||||
|
2. Check if decode path correctly loads KV from CPU
|
||||||
|
3. Verify decode buffer is being written correctly
|
||||||
|
4. Compare decode attention outputs between offload and non-offload modes
|
||||||
|
|
||||||
|
## Key Files
|
||||||
|
- `nanovllm/layers/attention.py` - Main attention implementation with chunked paths
|
||||||
|
- `nanovllm/kvcache/offload_engine.py` - CPU-GPU transfer engine
|
||||||
|
- `nanovllm/kvcache/hybrid_manager.py` - KV cache management with `prefilled_blocks`
|
||||||
|
- `nanovllm/engine/model_runner.py` - Prefill/decode orchestration
|
||||||
|
|
||||||
|
## Hypothesis
|
||||||
|
The decode path fails because:
|
||||||
|
1. `prefilled_blocks` is not being tracked correctly, causing `get_prefilled_cpu_blocks()` to return empty
|
||||||
|
2. OR the decode attention is not correctly loading/using the prefilled KV from CPU
|
||||||
|
3. OR there's a stream synchronization issue specific to decode path
|
||||||
178
bench.py
178
bench.py
@@ -2,7 +2,6 @@ import os
|
|||||||
import time
|
import time
|
||||||
from random import randint, seed
|
from random import randint, seed
|
||||||
from nanovllm import LLM, SamplingParams
|
from nanovllm import LLM, SamplingParams
|
||||||
from nanovllm.config import SparsePolicyType
|
|
||||||
|
|
||||||
|
|
||||||
def bench_decode(llm, num_seqs, input_len, output_len):
|
def bench_decode(llm, num_seqs, input_len, output_len):
|
||||||
@@ -24,8 +23,8 @@ def bench_decode(llm, num_seqs, input_len, output_len):
|
|||||||
print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")
|
print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")
|
||||||
|
|
||||||
|
|
||||||
def bench_prefill(llm, num_seqs, input_len, label=""):
|
def bench_prefill(llm, num_seqs, input_len):
|
||||||
"""Benchmark prefill performance. Returns throughput."""
|
"""Benchmark prefill performance"""
|
||||||
seed(0)
|
seed(0)
|
||||||
# Fixed length input, minimal output to focus on prefill
|
# Fixed length input, minimal output to focus on prefill
|
||||||
prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
|
prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
|
||||||
@@ -36,28 +35,7 @@ def bench_prefill(llm, num_seqs, input_len, label=""):
|
|||||||
t = time.time() - t
|
t = time.time() - t
|
||||||
total_input_tokens = num_seqs * input_len
|
total_input_tokens = num_seqs * input_len
|
||||||
throughput = total_input_tokens / t
|
throughput = total_input_tokens / t
|
||||||
label_str = f" ({label})" if label else ""
|
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
||||||
print(f"[Prefill{label_str}] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
|
||||||
return throughput
|
|
||||||
|
|
||||||
|
|
||||||
def create_llm(path, max_len, enable_minference=False, minference_budget=0.3,
|
|
||||||
minference_vertical=1000, minference_slash=6096,
|
|
||||||
gpu_utilization=0.8):
|
|
||||||
"""Create LLM with specified configuration."""
|
|
||||||
kwargs = {
|
|
||||||
"enforce_eager": True, # MInference uses Triton, not compatible with CUDA graphs
|
|
||||||
"max_model_len": max_len,
|
|
||||||
"max_num_batched_tokens": max_len,
|
|
||||||
"gpu_memory_utilization": gpu_utilization,
|
|
||||||
}
|
|
||||||
if enable_minference:
|
|
||||||
kwargs["sparse_policy"] = SparsePolicyType.MINFERENCE
|
|
||||||
kwargs["minference_adaptive_budget"] = minference_budget
|
|
||||||
kwargs["minference_vertical_size"] = minference_vertical
|
|
||||||
kwargs["minference_slash_size"] = minference_slash
|
|
||||||
|
|
||||||
return LLM(path, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -68,17 +46,24 @@ def main():
|
|||||||
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
|
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
|
||||||
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
|
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
|
||||||
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
|
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
|
||||||
parser.add_argument("--enable-minference", action="store_true", help="Enable MInference sparse prefill")
|
|
||||||
parser.add_argument("--minference-budget", type=float, default=0.3, help="MInference adaptive budget (default: 0.3, use 0 for fixed mode)")
|
|
||||||
parser.add_argument("--minference-vertical", type=int, default=1000, help="Fixed vertical_size (only used when budget=0)")
|
|
||||||
parser.add_argument("--minference-slash", type=int, default=6096, help="Fixed slash_size (only used when budget=0)")
|
|
||||||
parser.add_argument("--gpu-utilization", type=float, default=0.9, help="GPU memory utilization (default: 0.9)")
|
|
||||||
parser.add_argument("--compare", action="store_true", help="Compare baseline vs MInference (runs both)")
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
||||||
max_len = args.max_len
|
max_len = args.max_len
|
||||||
|
|
||||||
|
print(f"\n[nanovllm GPU] max_len={max_len}")
|
||||||
|
|
||||||
|
llm = LLM(
|
||||||
|
path,
|
||||||
|
enforce_eager=False,
|
||||||
|
max_model_len=max_len,
|
||||||
|
max_num_batched_tokens=max_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warmup
|
||||||
|
print("\nWarming up...")
|
||||||
|
llm.generate(["Benchmark warmup: "], SamplingParams(max_tokens=10))
|
||||||
|
|
||||||
# Default input lengths
|
# Default input lengths
|
||||||
prefill_input_len = args.input_len if args.input_len else max_len - 1
|
prefill_input_len = args.input_len if args.input_len else max_len - 1
|
||||||
decode_input_len = args.input_len if args.input_len else max_len - args.output_len
|
decode_input_len = args.input_len if args.input_len else max_len - args.output_len
|
||||||
@@ -87,128 +72,17 @@ def main():
|
|||||||
run_prefill = not args.bench_decode or args.bench_all
|
run_prefill = not args.bench_decode or args.bench_all
|
||||||
run_decode = args.bench_decode or args.bench_all
|
run_decode = args.bench_decode or args.bench_all
|
||||||
|
|
||||||
# Convert budget=0 to None for fixed mode
|
if run_prefill:
|
||||||
minference_budget = args.minference_budget if args.minference_budget > 0 else None
|
print("\n" + "=" * 60)
|
||||||
|
print("Prefill Benchmark (nanovllm GPU)")
|
||||||
|
print("=" * 60)
|
||||||
|
bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
|
||||||
|
|
||||||
if args.compare:
|
if run_decode:
|
||||||
# Compare baseline vs MInference using subprocesses to avoid NCCL issues
|
print("\n" + "=" * 60)
|
||||||
import subprocess
|
print("Decode Benchmark (nanovllm GPU)")
|
||||||
import sys
|
print("=" * 60)
|
||||||
|
bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"Baseline vs MInference Comparison")
|
|
||||||
print(f"Input length: {prefill_input_len} tokens")
|
|
||||||
if minference_budget is not None:
|
|
||||||
print(f"MInference mode: adaptive (budget={minference_budget}, {minference_budget*100:.0f}% compute)")
|
|
||||||
else:
|
|
||||||
print(f"MInference mode: fixed (vertical={args.minference_vertical}, slash={args.minference_slash})")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
|
|
||||||
# Get PYTHONPATH for subprocess
|
|
||||||
pythonpath = os.environ.get("PYTHONPATH", "")
|
|
||||||
|
|
||||||
# Run baseline in subprocess
|
|
||||||
print(f"\n[1/2] Running baseline (FULL attention)...")
|
|
||||||
cmd_baseline = [
|
|
||||||
sys.executable, __file__,
|
|
||||||
"--input-len", str(prefill_input_len),
|
|
||||||
"--max-len", str(max_len),
|
|
||||||
"--gpu-utilization", str(args.gpu_utilization),
|
|
||||||
]
|
|
||||||
env = os.environ.copy()
|
|
||||||
result = subprocess.run(cmd_baseline, capture_output=True, text=True, env=env)
|
|
||||||
print(result.stdout)
|
|
||||||
if result.returncode != 0:
|
|
||||||
print(f"Error: {result.stderr}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Parse baseline throughput
|
|
||||||
baseline_throughput = None
|
|
||||||
for line in result.stdout.split('\n'):
|
|
||||||
if "Throughput:" in line and "tok/s" in line:
|
|
||||||
# Extract throughput value
|
|
||||||
import re
|
|
||||||
match = re.search(r'Throughput:\s*([\d.]+)tok/s', line)
|
|
||||||
if match:
|
|
||||||
baseline_throughput = float(match.group(1))
|
|
||||||
|
|
||||||
# Run MInference in subprocess
|
|
||||||
if minference_budget is not None:
|
|
||||||
print(f"\n[2/2] Running MInference (budget={minference_budget})...")
|
|
||||||
else:
|
|
||||||
print(f"\n[2/2] Running MInference (vertical={args.minference_vertical}, slash={args.minference_slash})...")
|
|
||||||
cmd_minference = [
|
|
||||||
sys.executable, __file__,
|
|
||||||
"--input-len", str(prefill_input_len),
|
|
||||||
"--max-len", str(max_len),
|
|
||||||
"--gpu-utilization", str(args.gpu_utilization),
|
|
||||||
"--enable-minference",
|
|
||||||
"--minference-budget", str(args.minference_budget),
|
|
||||||
"--minference-vertical", str(args.minference_vertical),
|
|
||||||
"--minference-slash", str(args.minference_slash),
|
|
||||||
]
|
|
||||||
result = subprocess.run(cmd_minference, capture_output=True, text=True, env=env)
|
|
||||||
print(result.stdout)
|
|
||||||
if result.returncode != 0:
|
|
||||||
print(f"Error: {result.stderr}")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Parse MInference throughput
|
|
||||||
minference_throughput = None
|
|
||||||
for line in result.stdout.split('\n'):
|
|
||||||
if "Throughput:" in line and "tok/s" in line:
|
|
||||||
import re
|
|
||||||
match = re.search(r'Throughput:\s*([\d.]+)tok/s', line)
|
|
||||||
if match:
|
|
||||||
minference_throughput = float(match.group(1))
|
|
||||||
|
|
||||||
# Comparison
|
|
||||||
if baseline_throughput and minference_throughput:
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"Results Summary")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f"Baseline: {baseline_throughput:,.0f} tok/s")
|
|
||||||
print(f"MInference: {minference_throughput:,.0f} tok/s")
|
|
||||||
speedup = minference_throughput / baseline_throughput
|
|
||||||
if speedup >= 1.0:
|
|
||||||
print(f"Speedup: {speedup:.2f}x faster")
|
|
||||||
else:
|
|
||||||
print(f"Slowdown: {1/speedup:.2f}x slower")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
else:
|
|
||||||
print("Failed to parse throughput values")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Single run mode
|
|
||||||
mode = "MInference" if args.enable_minference else "GPU"
|
|
||||||
print(f"\n[nanovllm {mode}] max_len={max_len}")
|
|
||||||
if args.enable_minference:
|
|
||||||
if minference_budget is not None:
|
|
||||||
print(f"MInference mode: adaptive (budget={minference_budget})")
|
|
||||||
else:
|
|
||||||
print(f"MInference mode: fixed (vertical={args.minference_vertical}, slash={args.minference_slash})")
|
|
||||||
|
|
||||||
llm = create_llm(path, max_len, enable_minference=args.enable_minference,
|
|
||||||
minference_budget=minference_budget,
|
|
||||||
minference_vertical=args.minference_vertical,
|
|
||||||
minference_slash=args.minference_slash,
|
|
||||||
gpu_utilization=args.gpu_utilization)
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
print("\nWarming up...")
|
|
||||||
llm.generate(["Benchmark warmup: "], SamplingParams(max_tokens=10))
|
|
||||||
|
|
||||||
if run_prefill:
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print(f"Prefill Benchmark (nanovllm {mode})")
|
|
||||||
print("=" * 60)
|
|
||||||
bench_prefill(llm, num_seqs=1, input_len=prefill_input_len)
|
|
||||||
|
|
||||||
if run_decode:
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print(f"Decode Benchmark (nanovllm {mode})")
|
|
||||||
print("=" * 60)
|
|
||||||
bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ["VLLM_USE_V1"] = "1"
|
os.environ["VLLM_USE_V1"] = "1"
|
||||||
import time
|
import time
|
||||||
from random import randint, seed
|
from random import randint, seed
|
||||||
@@ -9,12 +8,8 @@ from vllm import LLM, SamplingParams
|
|||||||
def bench_decode(llm, num_seqs, input_len, output_len):
|
def bench_decode(llm, num_seqs, input_len, output_len):
|
||||||
"""Benchmark decode performance"""
|
"""Benchmark decode performance"""
|
||||||
seed(0)
|
seed(0)
|
||||||
prompt_token_ids = [
|
prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
|
||||||
[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)
|
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=output_len)
|
||||||
]
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
temperature=0.6, ignore_eos=True, max_tokens=output_len
|
|
||||||
)
|
|
||||||
prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
|
prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
|
||||||
|
|
||||||
t = time.time()
|
t = time.time()
|
||||||
@@ -26,21 +21,15 @@ def bench_decode(llm, num_seqs, input_len, output_len):
|
|||||||
decode_tokens = num_seqs * output_len
|
decode_tokens = num_seqs * output_len
|
||||||
decode_throughput = decode_tokens / t
|
decode_throughput = decode_tokens / t
|
||||||
|
|
||||||
print(
|
print(f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s")
|
||||||
f"[Decode] Input: {num_seqs}x{input_len}tok, Output: {decode_tokens}tok, Time: {t:.2f}s"
|
print(f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)")
|
||||||
)
|
|
||||||
print(
|
|
||||||
f" Throughput: {decode_throughput:.2f} tok/s (includes prefill overhead)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def bench_prefill(llm, num_seqs, input_len):
|
def bench_prefill(llm, num_seqs, input_len):
|
||||||
"""Benchmark prefill performance"""
|
"""Benchmark prefill performance"""
|
||||||
seed(0)
|
seed(0)
|
||||||
# Fixed length input, minimal output to focus on prefill
|
# Fixed length input, minimal output to focus on prefill
|
||||||
prompt_token_ids = [
|
prompt_token_ids = [[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)]
|
||||||
[randint(0, 10000) for _ in range(input_len)] for _ in range(num_seqs)
|
|
||||||
]
|
|
||||||
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
|
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
|
||||||
prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
|
prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids]
|
||||||
|
|
||||||
@@ -49,39 +38,17 @@ def bench_prefill(llm, num_seqs, input_len):
|
|||||||
t = time.time() - t
|
t = time.time() - t
|
||||||
total_input_tokens = num_seqs * input_len
|
total_input_tokens = num_seqs * input_len
|
||||||
throughput = total_input_tokens / t
|
throughput = total_input_tokens / t
|
||||||
print(
|
print(f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s")
|
||||||
f"[Prefill] Input: {total_input_tokens}tok ({num_seqs}x{input_len}), Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import argparse
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser(description="Benchmark vLLM performance (for comparison)")
|
||||||
parser = argparse.ArgumentParser(
|
parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
|
||||||
description="Benchmark vLLM performance (for comparison)"
|
parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
|
||||||
)
|
parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
|
||||||
parser.add_argument(
|
parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
|
||||||
"--input-len", type=int, default=None, help="Input length in tokens"
|
parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--output-len",
|
|
||||||
type=int,
|
|
||||||
default=64,
|
|
||||||
help="Output length for decode benchmark (default: 64)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-len", type=int, default=32 * 1024, help="Max model length (default: 32K)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--bench-decode",
|
|
||||||
action="store_true",
|
|
||||||
help="Run decode benchmark (default: prefill only)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--bench-all",
|
|
||||||
action="store_true",
|
|
||||||
help="Run both prefill and decode benchmarks",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
|
||||||
@@ -94,7 +61,7 @@ def main():
|
|||||||
enforce_eager=False,
|
enforce_eager=False,
|
||||||
max_model_len=max_len,
|
max_model_len=max_len,
|
||||||
max_num_seqs=128,
|
max_num_seqs=128,
|
||||||
gpu_memory_utilization=0.7,
|
gpu_memory_utilization=0.9,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Warmup
|
# Warmup
|
||||||
@@ -119,9 +86,7 @@ def main():
|
|||||||
print("\n" + "=" * 60)
|
print("\n" + "=" * 60)
|
||||||
print("Decode Benchmark (vLLM)")
|
print("Decode Benchmark (vLLM)")
|
||||||
print("=" * 60)
|
print("=" * 60)
|
||||||
bench_decode(
|
bench_decode(llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len)
|
||||||
llm, num_seqs=1, input_len=decode_input_len, output_len=args.output_len
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -1,131 +0,0 @@
|
|||||||
# 64k 推理内存分析
|
|
||||||
|
|
||||||
本文档分析 Llama 3.1 8B 模型在 64k 长度推理时的内存占用,以及 RTX 3090 (24GB) 上的 OOM 问题。
|
|
||||||
|
|
||||||
## 模型配置
|
|
||||||
|
|
||||||
```python
|
|
||||||
hidden_size = 4096
|
|
||||||
intermediate_size = 14336
|
|
||||||
num_layers = 32
|
|
||||||
num_heads = 32
|
|
||||||
num_kv_heads = 8
|
|
||||||
head_dim = 128
|
|
||||||
seq_len = 65536
|
|
||||||
dtype = bfloat16 (2 bytes)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 理论内存占用
|
|
||||||
|
|
||||||
### GPU Only 模式
|
|
||||||
|
|
||||||
| 组件 | 计算公式 | 内存占用 |
|
|
||||||
|------|----------|----------|
|
|
||||||
| 模型权重 | 8.03B × 2 bytes | **16.06 GB** |
|
|
||||||
| KV Cache | 32 × 65536 × 8 × 128 × 2 × 2 | **8.19 GB** |
|
|
||||||
| Prefill 激活值峰值 | max(QKV, MLP) | **~2 GB** |
|
|
||||||
| **总计** | | **~26 GB** |
|
|
||||||
|
|
||||||
**结论**:GPU only 模式需要 ~26 GB,**RTX 3090 (24GB) 无法运行**。
|
|
||||||
|
|
||||||
### CPU Offload 模式
|
|
||||||
|
|
||||||
| 组件 | 计算公式 | 内存占用 |
|
|
||||||
|------|----------|----------|
|
|
||||||
| 模型权重 | 8.03B × 2 bytes | **16.06 GB** |
|
|
||||||
| Ring buffer | num_kv_buffers × seq_len × 128 KB/token | 258-1034 MB |
|
|
||||||
| GPU KV blocks | num_gpu_blocks × block_size × 128 KB/token | 256 MB (2 blocks) |
|
|
||||||
| Per-layer decode buffer | 32 layers × 缓冲 | 128 MB |
|
|
||||||
| 激活值峰值 (chunked) | chunk_size × hidden_size × 2 | ~50 MB |
|
|
||||||
| PyTorch 开销 | CUDA 上下文 + 碎片 | ~5-6 GB |
|
|
||||||
| **理论小计** | | **~17.5 GB** |
|
|
||||||
| **实际需求** | | **~23 GB** |
|
|
||||||
|
|
||||||
**配置参数**:
|
|
||||||
- `num_kv_buffers`: Ring buffer 大小 (1-4),默认 4
|
|
||||||
- `num_gpu_blocks`: GPU 上的 KV cache block 数量
|
|
||||||
- `block_size`: 每个 block 的 token 数
|
|
||||||
|
|
||||||
## OOM 问题分析
|
|
||||||
|
|
||||||
### 实际观测(RTX 3090, num_kv_buffers=1)
|
|
||||||
|
|
||||||
```
|
|
||||||
PyTorch allocated: 22.49 GB
|
|
||||||
PyTorch reserved: 429 MB
|
|
||||||
Free: 306 MB
|
|
||||||
Total available: 735 MB
|
|
||||||
Failed to allocate: 508 MB (torch.cat)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 内存碎片来源
|
|
||||||
|
|
||||||
| 来源 | 说明 | 影响 |
|
|
||||||
|------|------|------|
|
|
||||||
| Binned 分配器 | PyTorch 使用固定大小的内存池 | 中等 |
|
|
||||||
| torch.compile 缓存 | 编译后的 kernel 代码和常量 | 高 (~2-3 GB) |
|
|
||||||
| 频繁分配/释放 | chunked 处理中每个 chunk 的创建销毁 | 高 |
|
|
||||||
| 不同大小张量 | (128,4096), (65536,6144) 等 | 中等 |
|
|
||||||
|
|
||||||
### torch.cat 内存需求
|
|
||||||
|
|
||||||
Chunked MLP 处理(chunk_size=128):
|
|
||||||
```
|
|
||||||
65536 / 128 = 512 chunks
|
|
||||||
每个 chunk 输出: (128, 4096) × 2 bytes = 1 MB
|
|
||||||
torch.cat 拼接需要: (65536, 4096) × 2 bytes = 508 MB (连续)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 已尝试的优化
|
|
||||||
|
|
||||||
| 优化项 | 效果 |
|
|
||||||
|--------|------|
|
|
||||||
| 移除 `@torch.compile` | PyTorch: 23.13 → 22.80 GB (-300 MB) |
|
|
||||||
| 减少 `num_kv_buffers` (4→1) | Ring buffer: 1034 → 258 MB (-776 MB) |
|
|
||||||
| Chunked QKV/MLP/LayerNorm | 峰值激活: ~2 GB → ~50 MB |
|
|
||||||
| 降低 GPU 利用率 (0.9→0.75) | 无明显效果 |
|
|
||||||
| 减小 chunk_size (4096→128) | 峰值降低,但 torch.cat 需要连续内存 |
|
|
||||||
|
|
||||||
### 最终状态
|
|
||||||
|
|
||||||
```
|
|
||||||
理论需求: ~17.5 GB
|
|
||||||
实际分配: 22.49 GB
|
|
||||||
剩余空间: 735 MB (306 MB + 429 MB reserved)
|
|
||||||
分配失败: 508 MB (torch.cat 需要连续内存)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 结论
|
|
||||||
|
|
||||||
### 根本原因
|
|
||||||
|
|
||||||
**不是绝对内存不足,而是内存碎片导致的分配失败**。
|
|
||||||
|
|
||||||
理论需求 17.5 GB < 24 GB,但由于:
|
|
||||||
- PyTorch 开销(CUDA 上下文、碎片):~5-6 GB
|
|
||||||
- torch.compile 缓存:~2-3 GB(已移除)
|
|
||||||
- 内存碎片导致无法分配 508 MB 连续块
|
|
||||||
|
|
||||||
### 硬件限制
|
|
||||||
|
|
||||||
| GPU | 显存 | 64k GPU Only | 64k Offload |
|
|
||||||
|-----|------|--------------|--------------|
|
|
||||||
| RTX 3090 | 24 GB | ❌ | ⚠️ 碎片问题 |
|
|
||||||
| RTX 4090 | 24 GB | ❌ | ⚠️ 碎片问题 |
|
|
||||||
| A100 | 40 GB | ✅ | ✅ |
|
|
||||||
| A100 | 80 GB | ✅ | ✅ |
|
|
||||||
|
|
||||||
### 建议
|
|
||||||
|
|
||||||
1. **64k 推理建议使用 40GB+ 显存的 GPU**
|
|
||||||
2. RTX 3090/4090 适合 32k 或更短的场景
|
|
||||||
3. 如必须在 24GB GPU 上运行 64k:
|
|
||||||
- 使用 RAPIDS RMM 分配器
|
|
||||||
- 预分配 torch.cat 需要的内存
|
|
||||||
- 或使用流式处理避免 torch.cat
|
|
||||||
|
|
||||||
## 参考
|
|
||||||
|
|
||||||
- [PyTorch 内存管理文档](https://docs.pytorch.org/docs/stable/generated/torch.cuda.memory.memory_stats.html)
|
|
||||||
- [PyTorch 内存碎片讨论](https://discuss.pytorch.org/t/how-to-reduce-memory-fragmentation-when-enable-expandable-segments/221805)
|
|
||||||
- [STWeaver - 减少 79% 内存碎片](https://arxiv.org/html/2507.16274v1)
|
|
||||||
@@ -1,161 +0,0 @@
|
|||||||
# 64K Prefill MLP Activation OOM Issue
|
|
||||||
|
|
||||||
## Problem Summary
|
|
||||||
|
|
||||||
When running RULER benchmark with 64K context length using CPU offload mode, OOM occurs during MLP forward pass in `run_layerwise_offload_prefill`. The KV cache is successfully offloaded to CPU, but MLP intermediate activations exceed available GPU memory.
|
|
||||||
|
|
||||||
## Environment
|
|
||||||
|
|
||||||
- GPU: RTX 3090 (24GB)
|
|
||||||
- Model: LLaMA 3.1 8B
|
|
||||||
- Sequence Length: 65536 tokens
|
|
||||||
- Mode: `enable_cpu_offload=True`, `num_gpu_blocks=2`
|
|
||||||
|
|
||||||
## Error Message
|
|
||||||
|
|
||||||
```
|
|
||||||
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.47 GiB.
|
|
||||||
GPU 0 has a total capacity of 23.57 GiB of which 2.66 GiB is free.
|
|
||||||
Including non-PyTorch memory, this process has 20.88 GiB memory in use.
|
|
||||||
Of the allocated memory 20.51 GiB is allocated by PyTorch, and 32.26 MiB
|
|
||||||
is reserved by PyTorch but unallocated.
|
|
||||||
```
|
|
||||||
|
|
||||||
## Stack Trace
|
|
||||||
|
|
||||||
```
|
|
||||||
File "nanovllm/engine/model_runner.py", line 843, in run_layerwise_offload_prefill
|
|
||||||
hidden_states = layer.mlp(hidden_states)
|
|
||||||
File "nanovllm/models/llama.py", line 103, in forward
|
|
||||||
gate_up = self.gate_up_proj(x)
|
|
||||||
File "nanovllm/layers/linear.py", line 73, in forward
|
|
||||||
return F.linear(x, self.weight, self.bias)
|
|
||||||
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 3.47 GiB.
|
|
||||||
```
|
|
||||||
|
|
||||||
## Root Cause Analysis
|
|
||||||
|
|
||||||
### Memory Breakdown
|
|
||||||
|
|
||||||
| Component | Calculation | Size |
|
|
||||||
|-----------|-------------|------|
|
|
||||||
| Model weights (BF16) | 8B params × 2 bytes | ~16 GB |
|
|
||||||
| GPU KV cache | 2 blocks × 1024 tokens × 8KB/token | ~16 MB |
|
|
||||||
| **Remaining for activations** | 24 - 16 - overhead | **~6-7 GB** |
|
|
||||||
|
|
||||||
### MLP Activation Memory (per layer)
|
|
||||||
|
|
||||||
For LLaMA 3.1 8B with `hidden_size=4096`, `intermediate_size=14336`:
|
|
||||||
|
|
||||||
| Tensor | Shape | Size (BF16) |
|
|
||||||
|--------|-------|-------------|
|
|
||||||
| MLP input | [65536, 4096] | 512 MB |
|
|
||||||
| gate_up output | [65536, 28672] | **3.47 GB** |
|
|
||||||
| down_proj input | [65536, 14336] | 1.75 GB |
|
|
||||||
| MLP output | [65536, 4096] | 512 MB |
|
|
||||||
|
|
||||||
**Peak MLP memory**: ~3.5-4 GB for intermediate tensors
|
|
||||||
|
|
||||||
### Why OOM Occurs
|
|
||||||
|
|
||||||
1. Model weights consume ~16 GB (loaded on GPU for layer-wise processing)
|
|
||||||
2. Available memory: ~7 GB
|
|
||||||
3. MLP `gate_up_proj` output: 3.47 GB
|
|
||||||
4. Additional tensors (input, gradients, etc.): ~1-2 GB
|
|
||||||
5. **Total required > Available** → OOM
|
|
||||||
|
|
||||||
## Code Location
|
|
||||||
|
|
||||||
The issue is in `nanovllm/engine/model_runner.py`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Line 843 in run_layerwise_offload_prefill
|
|
||||||
hidden_states = layer.mlp(hidden_states) # <-- OOM here
|
|
||||||
```
|
|
||||||
|
|
||||||
The entire sequence (65536 tokens) is passed through MLP in one shot.
|
|
||||||
|
|
||||||
## Current Configuration
|
|
||||||
|
|
||||||
From `model_wrappers.py` (RULER integration):
|
|
||||||
|
|
||||||
```python
|
|
||||||
llm_kwargs = {
|
|
||||||
"max_model_len": max_model_len, # 128 * 1024
|
|
||||||
"max_num_batched_tokens": max_model_len, # Same as max_model_len
|
|
||||||
"enable_cpu_offload": True,
|
|
||||||
"num_gpu_blocks": 2,
|
|
||||||
...
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Setting `max_num_batched_tokens = max_model_len` causes nanovllm to process all tokens at once.
|
|
||||||
|
|
||||||
## Potential Solutions
|
|
||||||
|
|
||||||
### Option 1: Chunked MLP Processing
|
|
||||||
|
|
||||||
Modify `run_layerwise_offload_prefill` to process MLP in chunks:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Instead of:
|
|
||||||
hidden_states = layer.mlp(hidden_states)
|
|
||||||
|
|
||||||
# Do:
|
|
||||||
chunk_size = 8192 # Process 8K tokens at a time
|
|
||||||
chunks = hidden_states.split(chunk_size, dim=0)
|
|
||||||
outputs = []
|
|
||||||
for chunk in chunks:
|
|
||||||
outputs.append(layer.mlp(chunk))
|
|
||||||
hidden_states = torch.cat(outputs, dim=0)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Option 2: Activation Checkpointing
|
|
||||||
|
|
||||||
Use gradient checkpointing to recompute activations instead of storing them:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from torch.utils.checkpoint import checkpoint
|
|
||||||
hidden_states = checkpoint(layer.mlp, hidden_states, use_reentrant=False)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Option 3: Reduce Chunk Size via Config
|
|
||||||
|
|
||||||
Add a new config parameter `prefill_chunk_size` to control how many tokens are processed per forward pass.
|
|
||||||
|
|
||||||
## Memory Estimation Formula
|
|
||||||
|
|
||||||
For a given sequence length `S` and model config:
|
|
||||||
|
|
||||||
```
|
|
||||||
MLP_peak_memory = S × intermediate_size × 2 × 2 bytes
|
|
||||||
= S × 14336 × 4 bytes
|
|
||||||
|
|
||||||
For S = 65536:
|
|
||||||
MLP_peak = 65536 × 14336 × 4 = 3.76 GB
|
|
||||||
```
|
|
||||||
|
|
||||||
Maximum safe sequence length for RTX 3090 (24GB):
|
|
||||||
```
|
|
||||||
S_max = available_memory / (intermediate_size × 4)
|
|
||||||
= 6GB / (14336 × 4)
|
|
||||||
≈ 100K tokens (theoretical)
|
|
||||||
≈ 8-16K tokens (practical, with safety margin)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Reproduction Steps
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd /home/zijie/Code/COMPASS/eval/RULER/scripts
|
|
||||||
|
|
||||||
# Set SEQ_LENGTHS to 65536 in config_models.sh
|
|
||||||
# Then run:
|
|
||||||
./run.sh llama3.1-8b-nanovllm synthetic --metric full --task niah_single_1
|
|
||||||
```
|
|
||||||
|
|
||||||
## Related Files
|
|
||||||
|
|
||||||
- `nanovllm/engine/model_runner.py`: `run_layerwise_offload_prefill()` (line 751+)
|
|
||||||
- `nanovllm/models/llama.py`: `LlamaMLP.forward()` (line 103)
|
|
||||||
- `nanovllm/config.py`: Config parameters
|
|
||||||
- RULER integration: `eval/RULER/scripts/pred/model_wrappers.py`
|
|
||||||
@@ -1,189 +0,0 @@
|
|||||||
# Architecture Guide
|
|
||||||
|
|
||||||
This document describes the core architecture and layer-wise CPU offload system of nano-vLLM.
|
|
||||||
|
|
||||||
## Core Components
|
|
||||||
|
|
||||||
| Component | File | Purpose |
|
|
||||||
|-----------|------|---------|
|
|
||||||
| **LLMEngine** | `llm_engine.py` | Main entry, runs prefill-decode loop |
|
|
||||||
| **ModelRunner** | `model_runner.py` | Loads weights, allocates KV cache, CUDA graphs, layer-wise offload |
|
|
||||||
| **Scheduler** | `scheduler.py` | Two-phase scheduling (prefill → decode) |
|
|
||||||
| **BlockManager** | `block_manager.py` | Paged attention with prefix caching (xxhash), default block size 4096 |
|
|
||||||
| **Attention** | `layers/attention.py` | FlashAttention for standard inference |
|
|
||||||
|
|
||||||
## Layer-wise CPU Offload System
|
|
||||||
|
|
||||||
### Design Philosophy
|
|
||||||
|
|
||||||
Unlike chunked prefill (which processes chunks across all layers), **layer-wise offload** processes the entire sequence through one layer at a time:
|
|
||||||
|
|
||||||
```
|
|
||||||
Layer 0: [full sequence] → compute → offload K,V to CPU
|
|
||||||
Layer 1: [full sequence] → compute → offload K,V to CPU
|
|
||||||
...
|
|
||||||
Layer N: [full sequence] → compute → offload K,V to CPU
|
|
||||||
```
|
|
||||||
|
|
||||||
**Benefits**:
|
|
||||||
- Supports MInference sparse attention (requires full KV access per layer)
|
|
||||||
- Simpler memory management (one layer's KV in GPU at a time)
|
|
||||||
- Peak GPU memory = one layer's KV cache + attention workspace
|
|
||||||
|
|
||||||
### Key Files
|
|
||||||
|
|
||||||
| File | Purpose |
|
|
||||||
|------|---------|
|
|
||||||
| `nanovllm/engine/model_runner.py` | Main implementation (`run_layerwise_offload_prefill`, `run_layerwise_offload_decode`) |
|
|
||||||
| `nanovllm/kvcache/hybrid_manager.py` | CPU block management helpers |
|
|
||||||
| `nanovllm/kvcache/offload_engine.py` | CPU/GPU cache storage, ring buffer, async transfers |
|
|
||||||
|
|
||||||
### Memory Layout
|
|
||||||
|
|
||||||
**CPU Cache** (pinned memory):
|
|
||||||
```python
|
|
||||||
k_cache_cpu: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]
|
|
||||||
v_cache_cpu: [num_layers, num_cpu_blocks, block_size, kv_heads, head_dim]
|
|
||||||
```
|
|
||||||
|
|
||||||
**GPU Ring Buffer** (for decode H2D pipeline):
|
|
||||||
```python
|
|
||||||
layer_k_cache: [num_kv_buffers, max_seq_len, kv_heads, head_dim]
|
|
||||||
layer_v_cache: [num_kv_buffers, max_seq_len, kv_heads, head_dim]
|
|
||||||
```
|
|
||||||
|
|
||||||
**Per-layer KV size** (Qwen3-4B: 8 kv_heads × 128 head_dim × 2 bytes × 2 for K+V = 4KB/token):
|
|
||||||
|
|
||||||
| Context Length | KV per Layer |
|
|
||||||
|----------------|--------------|
|
|
||||||
| 128K tokens | 512 MB |
|
|
||||||
| 256K tokens | 1 GB |
|
|
||||||
| 512K tokens | 2 GB |
|
|
||||||
| 1M tokens | 4 GB |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Prefill Flow
|
|
||||||
|
|
||||||
```python
|
|
||||||
def run_layerwise_offload_prefill(self, seqs: list[Sequence]) -> list[int]:
|
|
||||||
# 1. Embedding
|
|
||||||
hidden_states = self.model.model.embed_tokens(input_ids)
|
|
||||||
|
|
||||||
# 2. Process each layer
|
|
||||||
for layer_id in range(num_layers):
|
|
||||||
# QKV projection + norms + RoPE
|
|
||||||
q = apply_rotary_pos_emb(q_proj(hidden_states), cos, sin)
|
|
||||||
k = apply_rotary_pos_emb(k_proj(hidden_states), cos, sin)
|
|
||||||
v = v_proj(hidden_states)
|
|
||||||
|
|
||||||
# Full FlashAttention (entire sequence)
|
|
||||||
attn_out = flash_attn_varlen_func(q, k, v, cu_seqlens, max_seqlen, causal=True)
|
|
||||||
|
|
||||||
# MLP
|
|
||||||
hidden_states = mlp(attn_out + residual)
|
|
||||||
|
|
||||||
# Synchronous offload to CPU (CRITICAL: must be sync to avoid memory reuse bugs)
|
|
||||||
self._offload_layer_kv_to_cpu_sync(layer_id, k, v, cpu_block_ids, total_tokens)
|
|
||||||
|
|
||||||
# 3. Final norm + sampling
|
|
||||||
return sampled_tokens
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Decode Flow
|
|
||||||
|
|
||||||
```python
|
|
||||||
def run_layerwise_offload_decode(self, seqs: list[Sequence]) -> list[int]:
|
|
||||||
# Ring buffer pipeline: preload first N layers
|
|
||||||
for i in range(num_buffers):
|
|
||||||
offload_engine.load_layer_kv_to_buffer(i, i, cpu_block_table, valid_tokens)
|
|
||||||
|
|
||||||
# For each layer:
|
|
||||||
for layer_id in range(num_layers):
|
|
||||||
current_buffer = layer_id % num_buffers
|
|
||||||
|
|
||||||
# 1. Wait for buffer load to complete
|
|
||||||
offload_engine.wait_buffer_load(current_buffer)
|
|
||||||
|
|
||||||
# 2. Get prefilled KV from ring buffer
|
|
||||||
k_prefill, v_prefill = offload_engine.get_buffer_kv(current_buffer, total_prefill_tokens)
|
|
||||||
|
|
||||||
# 3. Compute new Q,K,V for current token
|
|
||||||
q_new = apply_rotary_pos_emb(q_proj(hidden_states), cos, sin)
|
|
||||||
k_new = apply_rotary_pos_emb(k_proj(hidden_states), cos, sin)
|
|
||||||
v_new = v_proj(hidden_states)
|
|
||||||
|
|
||||||
# 4. Concatenate and compute attention
|
|
||||||
k_full = torch.cat([k_prefill, k_new], dim=0)
|
|
||||||
v_full = torch.cat([v_prefill, v_new], dim=0)
|
|
||||||
attn_out = flash_attn_varlen_func(q_new, k_full, v_full, ..., causal=False)
|
|
||||||
# Note: causal=False because single query token should attend to ALL keys
|
|
||||||
|
|
||||||
# 5. Mark buffer done, start loading next layer
|
|
||||||
offload_engine.record_buffer_compute_done(current_buffer)
|
|
||||||
if layer_id + num_buffers < num_layers:
|
|
||||||
offload_engine.load_layer_kv_to_buffer(current_buffer, layer_id + num_buffers, ...)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Critical Implementation Details
|
|
||||||
|
|
||||||
### 1. Synchronous Offload Required
|
|
||||||
|
|
||||||
Async offload with `non_blocking=True` causes memory reuse bugs:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# BUG: PyTorch may reuse k,v GPU memory before async copy completes
|
|
||||||
offload_engine.k_cache_cpu[layer_id, block_id].copy_(k[start:end], non_blocking=True)
|
|
||||||
|
|
||||||
# CORRECT: Synchronous copy ensures data integrity
|
|
||||||
offload_engine.k_cache_cpu[layer_id, block_id, :size].copy_(k[start:end]) # sync
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Decode Attention: causal=False
|
|
||||||
|
|
||||||
During decode, the single query token must attend to ALL keys (not just preceding ones):
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Prefill: causal=True (each token only attends to previous tokens)
|
|
||||||
attn_out = flash_attn_varlen_func(..., causal=True)
|
|
||||||
|
|
||||||
# Decode: causal=False (query at position N attends to all N-1 prefill + itself)
|
|
||||||
attn_out = flash_attn_varlen_func(..., causal=False)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Ring Buffer Synchronization
|
|
||||||
|
|
||||||
The ring buffer pipeline requires careful ordering:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# CORRECT order:
|
|
||||||
offload_engine.store_decode_kv(layer_id, pos, k_new, v_new) # Store new KV
|
|
||||||
offload_engine.record_buffer_compute_done(current_buffer) # Mark done FIRST
|
|
||||||
offload_engine.load_layer_kv_to_buffer(...) # THEN start next load
|
|
||||||
|
|
||||||
# BUG: Starting load before marking done causes race condition
|
|
||||||
offload_engine.load_layer_kv_to_buffer(...) # WRONG: buffer still in use!
|
|
||||||
offload_engine.record_buffer_compute_done(current_buffer)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Helper Methods in HybridKVCacheManager
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Get all CPU blocks for a sequence
|
|
||||||
cpu_blocks = manager.get_all_cpu_blocks(seq) # List[int]
|
|
||||||
|
|
||||||
# Get only prefilled (offloaded) CPU blocks
|
|
||||||
prefilled_blocks = manager.get_prefilled_cpu_blocks(seq) # List[int]
|
|
||||||
|
|
||||||
# Get cached prefill length (doesn't change during decode)
|
|
||||||
prefill_len = manager.get_prefill_len(seq) # int
|
|
||||||
|
|
||||||
# Get decode start position
|
|
||||||
decode_pos = manager.get_decode_start_pos(seq) # int
|
|
||||||
```
|
|
||||||
@@ -1,191 +0,0 @@
|
|||||||
# Block-Sparse-Attention Library Reference
|
|
||||||
|
|
||||||
MIT Han Lab 的块稀疏注意力内核库,基于 FlashAttention 2.4.2 修改,支持多种稀疏注意力模式。
|
|
||||||
|
|
||||||
## 库信息
|
|
||||||
|
|
||||||
- **来源**: [MIT-Han-Lab/Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention)
|
|
||||||
- **本地路径**: `3rdparty/Block-Sparse-Attention` (submodule, branch: `tzj/minference`)
|
|
||||||
- **基于**: FlashAttention 2.4.2
|
|
||||||
- **安装位置**: `site-packages/block_sparse_attn`
|
|
||||||
|
|
||||||
## 支持的稀疏模式
|
|
||||||
|
|
||||||
### 1. Dense Attention
|
|
||||||
计算完整注意力矩阵,无稀疏化。
|
|
||||||
|
|
||||||
### 2. Token Streaming (token granularity)
|
|
||||||
固定数量的 sink tokens + local tokens,参考 [StreamingLLM](https://arxiv.org/abs/2309.17453)。
|
|
||||||
|
|
||||||
**适用场景**: 需要精确保留部分关键 token 的长上下文推理
|
|
||||||
|
|
||||||
### 3. Block Streaming (block granularity)
|
|
||||||
Block 粒度的 streaming attention,block_size = 128。
|
|
||||||
|
|
||||||
**适用场景**: 长序列推理,牺牲少量精度换取更大加速
|
|
||||||
|
|
||||||
### 4. Block Sparse
|
|
||||||
基于自定义 block mask 的稀疏注意力。
|
|
||||||
|
|
||||||
**适用场景**: 已知特定 attention 模式的工作负载
|
|
||||||
|
|
||||||
### 混合模式
|
|
||||||
|
|
||||||
**关键特性**: 支持不同 head 使用不同稀疏模式
|
|
||||||
|
|
||||||
```python
|
|
||||||
# 8 个 heads 的混合配置示例
|
|
||||||
head_mask_type = [1, 1, 0, 0, 0, -1, 0, -1]
|
|
||||||
# 含义:
|
|
||||||
# - head 0,1: blocksparse (使用 basemask[0])
|
|
||||||
# - head 2-4,6: dense
|
|
||||||
# - head 5,7: streaming
|
|
||||||
```
|
|
||||||
|
|
||||||
**Mask 类型编码**:
|
|
||||||
- `0` = Dense attention
|
|
||||||
- `-1` = Streaming attention
|
|
||||||
- `1, 2, ...` = Block sparse (使用 basemask[mask_type - 1])
|
|
||||||
|
|
||||||
## API 参考
|
|
||||||
|
|
||||||
### `block_sparse_attn_func`
|
|
||||||
|
|
||||||
通用块稀疏注意力函数,支持所有模式。
|
|
||||||
|
|
||||||
```python
|
|
||||||
from block_sparse_attn import block_sparse_attn_func
|
|
||||||
|
|
||||||
output = block_sparse_attn_func(
|
|
||||||
q, k, v, # [total_tokens, heads, head_dim] unpadded
|
|
||||||
cu_seqlens_q, cu_seqlens_k, # cumulative sequence lengths
|
|
||||||
head_mask_type, # [heads] tensor, 每个头的模式
|
|
||||||
streaming_info, # streaming 配置 (sink/local 数量)
|
|
||||||
base_blockmask, # [q_blocks, k_blocks, n_masks] bool tensor
|
|
||||||
max_seqlen_q, max_seqlen_k, # 最大序列长度
|
|
||||||
p_dropout, # dropout 概率 (推理时设为 0.0)
|
|
||||||
deterministic=False,
|
|
||||||
softmax_scale=None,
|
|
||||||
is_causal=False,
|
|
||||||
exact_streaming=False, # True=token streaming, False=block streaming
|
|
||||||
return_attn_probs=False,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**关键参数**:
|
|
||||||
| 参数 | 类型 | 说明 |
|
|
||||||
|------|------|------|
|
|
||||||
| `head_mask_type` | Tensor[heads] | 每个头的稀疏模式,0=dense, -1=streaming, 1+=blocksparse |
|
|
||||||
| `streaming_info` | Tensor | [sink_blocks, local_blocks] 或 [sink_tokens, local_tokens] |
|
|
||||||
| `base_blockmask` | Tensor | Block mask,形状 [q_blocks, k_blocks, n_masks] |
|
|
||||||
| `exact_streaming` | bool | True=token 粒度,False=block 粒度 streaming |
|
|
||||||
|
|
||||||
### `block_streaming_attn_func`
|
|
||||||
|
|
||||||
Block 粒度 streaming attention(block_size=128)。
|
|
||||||
|
|
||||||
```python
|
|
||||||
from block_sparse_attn import block_streaming_attn_func
|
|
||||||
|
|
||||||
output = block_streaming_attn_func(
|
|
||||||
q, k, v,
|
|
||||||
cu_seqlens_q, cu_seqlens_k,
|
|
||||||
head_mask_type,
|
|
||||||
streaming_info, # [sink_blocks, local_blocks]
|
|
||||||
max_seqlen_q, max_seqlen_k,
|
|
||||||
p_dropout,
|
|
||||||
deterministic=False,
|
|
||||||
softmax_scale=None,
|
|
||||||
is_causal=True,
|
|
||||||
return_attn_probs=False,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### `token_streaming_attn_func`
|
|
||||||
|
|
||||||
Token 粒度 streaming attention。
|
|
||||||
|
|
||||||
**注意**: 不支持反向传播(仅推理)。
|
|
||||||
|
|
||||||
```python
|
|
||||||
from block_sparse_attn import token_streaming_attn_func
|
|
||||||
|
|
||||||
output = token_streaming_attn_func(
|
|
||||||
q, k, v,
|
|
||||||
cu_seqlens_q, cu_seqlens_k,
|
|
||||||
head_mask_type,
|
|
||||||
streaming_info, # [sink_tokens, local_tokens]
|
|
||||||
max_seqlen_q, max_seqlen_k,
|
|
||||||
deterministic=False,
|
|
||||||
softmax_scale=None,
|
|
||||||
return_attn_probs=False,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 技术规格
|
|
||||||
|
|
||||||
| 特性 | 支持情况 |
|
|
||||||
|------|----------|
|
|
||||||
| **数据类型** | fp16, bf16 (bf16 需要 Ampere/Ada/Hopper GPU) |
|
|
||||||
| **Head 维度** | 32, 64, 128 |
|
|
||||||
| **Block Size** | 128 (固定) |
|
|
||||||
| **CUDA 要求** | 11.6+ |
|
|
||||||
| **PyTorch 要求** | 1.12+ |
|
|
||||||
|
|
||||||
## 性能参考
|
|
||||||
|
|
||||||
测试环境: A100 GPU, head_dim=128, 32 heads, batch_size=1
|
|
||||||
|
|
||||||
### Block Sparse 加速比
|
|
||||||
- 相比 FlashAttention2: 最高 **3-4x** 加速
|
|
||||||
- 加速随序列长度增加而提升
|
|
||||||
|
|
||||||
### Streaming 混合模式加速比
|
|
||||||
- Token streaming: 64 sink + 256 local tokens
|
|
||||||
- Block streaming: 1 sink block + 3 local blocks
|
|
||||||
- **50% Dense + 50% Streaming**: 最高 **2x** 加速
|
|
||||||
|
|
||||||
## 与 nano-vllm 的集成考虑
|
|
||||||
|
|
||||||
### 潜在集成点
|
|
||||||
|
|
||||||
1. **长上下文推理优化**
|
|
||||||
- 使用 block streaming 减少计算量
|
|
||||||
- 在 CPU offload 模式下减少 GPU-CPU 传输
|
|
||||||
|
|
||||||
2. **混合注意力策略**
|
|
||||||
- 部分 head 使用 streaming(减少计算)
|
|
||||||
- 部分 head 使用 dense(保持精度)
|
|
||||||
- 参考 Duo Attention 论文的混合模式
|
|
||||||
|
|
||||||
3. **稀疏 offload**
|
|
||||||
- 只 offload 重要 blocks 的 KV cache
|
|
||||||
- 结合 `requires_block_selection` 接口
|
|
||||||
|
|
||||||
### 实现注意事项
|
|
||||||
|
|
||||||
1. **输入格式**: 库使用 unpadded 格式(`cu_seqlens`),需要与 nano-vllm 的 padded 格式转换
|
|
||||||
2. **Block size 固定**: 库固定 block_size=128,需要适配
|
|
||||||
3. **Streaming info 配置**: 需要根据模型特性调整 sink/local 数量
|
|
||||||
|
|
||||||
## 相关工作
|
|
||||||
|
|
||||||
- [FlashAttention](https://github.com/Dao-AILab/flash-attention) - 基础实现
|
|
||||||
- [StreamingLLM](https://arxiv.org/abs/2309.17453) - Streaming attention 理论基础
|
|
||||||
- [Duo Attention](https://github.com/mit-han-lab/duo-attention) - 混合 dense/streaming 模式
|
|
||||||
- [MInference](https://arxiv.org/abs/2407.02490) - 混合 mask 方法
|
|
||||||
|
|
||||||
## 测试
|
|
||||||
|
|
||||||
库自带测试位于 `3rdparty/Block-Sparse-Attention/block_sparse_tests/`:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 正确性测试
|
|
||||||
cd 3rdparty/Block-Sparse-Attention/block_sparse_tests/fwd/test_correctness
|
|
||||||
pytest full_test.py
|
|
||||||
|
|
||||||
# 性能测试
|
|
||||||
cd 3rdparty/Block-Sparse-Attention/block_sparse_tests/fwd/test_performance
|
|
||||||
python token_streaming.py
|
|
||||||
python blocksparse.py
|
|
||||||
```
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,354 +0,0 @@
|
|||||||
# Chunked Prefill 集成计划
|
|
||||||
|
|
||||||
**目标**: 将 tzj/minference 分支的 chunked prefill 机制移植到 tzj/vs_offload 分支
|
|
||||||
|
|
||||||
**创建日期**: 2026-01-18
|
|
||||||
**基础分支**: `tzj/vs_offload`
|
|
||||||
**源分支**: `tzj/minference`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 目标
|
|
||||||
|
|
||||||
在 tzj/vs_offload 分支上实现 chunked prefill + layerwise offload 机制,支持在 24GB RTX 3090 上运行任意长度的推理(4M, 8M, 16M+ tokens)。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 核心问题
|
|
||||||
|
|
||||||
### tzj/vs_offload 分支的局限性
|
|
||||||
|
|
||||||
当前 tzj/vs_offload 分支的 GPU ring buffer 按 `max_seq_len` 分配,导致 GPU 内存随序列长度线性增长:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# 当前设计
|
|
||||||
self.layer_k_cache = torch.zeros(
|
|
||||||
num_kv_buffers, # e.g., 4
|
|
||||||
max_seq_len, # e.g., 131072 tokens
|
|
||||||
kv_heads,
|
|
||||||
head_dim,
|
|
||||||
dtype=dtype, device="cuda"
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**问题**:
|
|
||||||
- GPU 内存需求 ~ `max_seq_len × 4 × 8 × 128 × 2 bytes`
|
|
||||||
- 对于超长序列不可行:
|
|
||||||
- 4M tokens → ~64 GB GPU 内存 ❌
|
|
||||||
- 8M tokens → ~128 GB GPU 内存 ❌
|
|
||||||
|
|
||||||
### 解决方案:Block-Based 设计
|
|
||||||
|
|
||||||
tzj/minference 分支采用 block-based 设计,GPU 内存固定:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Block-based 设计
|
|
||||||
self.k_cache_gpu = torch.zeros(
|
|
||||||
num_gpu_blocks, # e.g., 2
|
|
||||||
block_size, # e.g., 1024 tokens (固定!)
|
|
||||||
kv_heads,
|
|
||||||
head_dim,
|
|
||||||
dtype=dtype, device="cuda"
|
|
||||||
)
|
|
||||||
# GPU 内存: ~4 MB (固定,不随序列长度增长)
|
|
||||||
```
|
|
||||||
|
|
||||||
**优势**:
|
|
||||||
- GPU 内存固定(~1.6 GB),不随序列长度增长
|
|
||||||
- 24GB RTX 3090 可运行 4M+ tokens
|
|
||||||
- 通过 chunked prefill 分块处理超长序列
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 内存布局对比
|
|
||||||
|
|
||||||
| 组件 | tzj/vs_offload | tzj/minference | 说明 |
|
|
||||||
|------|---------------|----------------|------|
|
|
||||||
| **GPU Ring Buffer** | `[num_kv_buffers, max_seq_len, ...]` | `[num_gpu_blocks, block_size, ...]` | minference 无 layer 维度 |
|
|
||||||
| **GPU 内存** | ~2.15 GB (128K) → ~64 GB (4M) | ~4 MB (固定) | minference 节省显著 |
|
|
||||||
| **Prefill Buffer** | ❌ 无 | ✅ `[num_layers, block_size, ...]` | minference 独有 |
|
|
||||||
| **Pipeline Buffers** | ❌ 无 | ✅ 双缓冲区 `[blocks, block_size, ...]` | minference 独有 |
|
|
||||||
| **CPU Cache** | `[num_layers, num_cpu_blocks, block_size, ...]` | 相同 | **一致** |
|
|
||||||
|
|
||||||
### 序列长度支持对比
|
|
||||||
|
|
||||||
| 序列长度 | vs_offload GPU 内存 | minference GPU 内存 | RTX 3090 (24GB) |
|
|
||||||
|----------|-------------------|---------------------|-----------------|
|
|
||||||
| 128K tokens | ~2.15 GB | ~4 MB | ✅ 两者均可 |
|
|
||||||
| 1M tokens | ~16 GB | ~4 MB | ✅ 两者均可 |
|
|
||||||
| **4M tokens** | **~64 GB** ❌ | **~4 MB** ✅ | **仅 minference 可行** |
|
|
||||||
| **8M tokens** | **~128 GB** ❌ | **~4 MB** ✅ | **仅 minference 可行** |
|
|
||||||
| **16M+ tokens** | **~256 GB+** ❌ | **~4 MB** ✅ | **仅 minference 可行** |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 关键设计原则
|
|
||||||
|
|
||||||
1. **Block-Based 设计**:按 `block_size` (1024 tokens) 组织,支持 chunked prefill
|
|
||||||
2. **GPU 内存固定**:不随序列长度增长,是 constant factor
|
|
||||||
3. **CPU 内存线性缩放**:`num_cpu_blocks = ceil(seq_len / block_size)`
|
|
||||||
4. **Unified Ring Buffer**:无 layer 维度,所有层共享 slots
|
|
||||||
5. **完全并行 offload**:per-layer buffer 最大化 PCIe 带宽
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 统一内存布局设计
|
|
||||||
|
|
||||||
### GPU Memory Layout
|
|
||||||
|
|
||||||
```python
|
|
||||||
class OffloadEngine:
|
|
||||||
# 1. Unified Ring Buffer - Block-based,无 layer 维度
|
|
||||||
self.k_cache_gpu = torch.zeros(
|
|
||||||
num_gpu_blocks, # e.g., 2
|
|
||||||
block_size, # e.g., 1024
|
|
||||||
kv_heads,
|
|
||||||
head_dim,
|
|
||||||
dtype=dtype, device="cuda"
|
|
||||||
) # ~4 MB (固定)
|
|
||||||
|
|
||||||
# 2. Per-layer Prefill Buffer - 完全并行 offload
|
|
||||||
self.prefill_k_buffer = torch.zeros(
|
|
||||||
num_layers, block_size, kv_heads, head_dim,
|
|
||||||
dtype=dtype, device="cuda"
|
|
||||||
) # ~58 MB (固定)
|
|
||||||
|
|
||||||
# 3. Cross-layer Pipeline Buffers - Double-buffering
|
|
||||||
self.layer_k_buffer_a = torch.zeros(
|
|
||||||
max_prefill_blocks, block_size, kv_heads, head_dim,
|
|
||||||
dtype=dtype, device="cuda"
|
|
||||||
) # ~512 MB (固定)
|
|
||||||
self.layer_k_buffer_b = torch.zeros(...) # ~512 MB (固定)
|
|
||||||
|
|
||||||
# 4. Per-layer Decode Buffer
|
|
||||||
self.decode_k_buffer = torch.zeros(
|
|
||||||
num_layers, block_size, kv_heads, head_dim,
|
|
||||||
dtype=dtype, device="cuda"
|
|
||||||
) # ~58 MB (固定)
|
|
||||||
|
|
||||||
# GPU 总计:~1.6 GB (固定,不随序列长度增长)
|
|
||||||
```
|
|
||||||
|
|
||||||
### CPU Memory Layout
|
|
||||||
|
|
||||||
```python
|
|
||||||
# CPU Cache - 有 block 维度
|
|
||||||
self.k_cache_cpu = torch.zeros(
|
|
||||||
num_layers,
|
|
||||||
num_cpu_blocks, # 随序列长度缩放
|
|
||||||
block_size,
|
|
||||||
kv_heads,
|
|
||||||
head_dim,
|
|
||||||
dtype=dtype, device="cpu", pin_memory=True
|
|
||||||
)
|
|
||||||
# 128K tokens: ~2.9 GB
|
|
||||||
# 1M tokens: ~5.8 GB
|
|
||||||
# 4M tokens: ~23.3 GB
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Chunked Prefill 流程
|
|
||||||
|
|
||||||
### Prefill 阶段
|
|
||||||
|
|
||||||
```
|
|
||||||
For each chunk:
|
|
||||||
├── 1. Prepare chunk input (block_size tokens)
|
|
||||||
├── 2. Get ring buffer slot: slot = chunk_idx % num_gpu_blocks
|
|
||||||
├── 3. Load previous KV chunks to ring slots[1..N-1]
|
|
||||||
├── 4. Model Forward (all layers)
|
|
||||||
│ For each layer:
|
|
||||||
│ ├── Load previous KV from ring slots
|
|
||||||
│ ├── Compute attention (current chunk + previous)
|
|
||||||
│ ├── Write KV to prefill_buffer[layer_id] ← Per-layer!
|
|
||||||
│ └── Async offload to CPU (parallel across layers)
|
|
||||||
├── 5. Merge attention outputs (LSE)
|
|
||||||
└── 6. Record compute done for slot
|
|
||||||
|
|
||||||
Key: Per-layer prefill buffer → Layer 0 offload || Layer 1 compute || Layer 2 load ...
|
|
||||||
```
|
|
||||||
|
|
||||||
### Decode 阶段
|
|
||||||
|
|
||||||
```
|
|
||||||
├── 1. Setup pipeline: preload Layer 0 to buffer_a
|
|
||||||
├── 2. For each layer:
|
|
||||||
│ ├── Get KV from pipeline buffer (a or b)
|
|
||||||
│ ├── Trigger preload of next layer to other buffer
|
|
||||||
│ ├── Compute attention
|
|
||||||
│ └── Store to decode buffer
|
|
||||||
└── 3. End pipeline
|
|
||||||
|
|
||||||
Key: Double-buffering → Layer N compute || Layer N+1 load
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 合并策略
|
|
||||||
|
|
||||||
### 基础分支选择:tzj/vs_offload
|
|
||||||
|
|
||||||
**原因**:
|
|
||||||
1. 更完善的文档系统
|
|
||||||
2. 更完整的 sparse attention 实现(QUEST, XAttention 等)
|
|
||||||
3. 更清晰的代码组织和注释
|
|
||||||
4. 更活跃的开发维护
|
|
||||||
|
|
||||||
### 移植策略
|
|
||||||
|
|
||||||
**从 tzj/minference 移植**:
|
|
||||||
1. GPU cache 内存布局(无 layer 维度,block-based)
|
|
||||||
2. Per-layer prefill buffer
|
|
||||||
3. Cross-layer pipeline buffers
|
|
||||||
4. Chunked prefill 流程
|
|
||||||
5. LSE 在线合并机制
|
|
||||||
|
|
||||||
**保留 tzj/vs_offload 优势**:
|
|
||||||
1. 文档系统
|
|
||||||
2. Sparse policy 架构
|
|
||||||
3. 代码组织和注释
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Sparse Policy 策略
|
|
||||||
|
|
||||||
**策略**:保留架构,现阶段仅实现 FULL
|
|
||||||
|
|
||||||
- **保留** sparse policy 的架构设计和接口
|
|
||||||
- **预留** 扩展接口给未来的 QUEST 等其他策略
|
|
||||||
- **现阶段仅实现** FULL 策略,确保正确性和稳定性
|
|
||||||
|
|
||||||
### 实现
|
|
||||||
|
|
||||||
```python
|
|
||||||
class SparsePolicy(ABC):
|
|
||||||
@property
|
|
||||||
def supports_prefill(self) -> bool:
|
|
||||||
return False
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_decode(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
|
|
||||||
"""预留给未来策略(如 QUEST 收集元数据)"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def select_blocks(self, available_blocks, context) -> List[int]:
|
|
||||||
"""FULL: 返回所有可用块"""
|
|
||||||
return available_blocks
|
|
||||||
|
|
||||||
class FullAttentionPolicy(SparsePolicy):
|
|
||||||
@property
|
|
||||||
def supports_prefill(self) -> bool:
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_decode(self) -> bool:
|
|
||||||
return True
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 关键 API
|
|
||||||
|
|
||||||
### Ring Buffer 管理
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Prefill 阶段
|
|
||||||
get_write_slot_for_prefill(chunk_idx) -> slot_idx
|
|
||||||
get_load_slots_for_prefill(write_slot_idx) -> [slot_ids]
|
|
||||||
|
|
||||||
# Decode 阶段
|
|
||||||
get_load_slots_for_decode() -> [slot_ids] (excludes decode_slot)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Per-layer 操作
|
|
||||||
|
|
||||||
```python
|
|
||||||
# 加载
|
|
||||||
load_to_slot_layer(slot_idx, layer_id, cpu_block_id)
|
|
||||||
wait_slot_layer(slot_idx)
|
|
||||||
|
|
||||||
# Prefill buffer
|
|
||||||
get_prefill_buffer(layer_id) -> (k, v)
|
|
||||||
offload_prefill_buffer_async(layer_id, cpu_block_id, num_tokens)
|
|
||||||
wait_prefill_offload(layer_id)
|
|
||||||
|
|
||||||
# Pipeline
|
|
||||||
start_decode_pipeline(cpu_block_ids)
|
|
||||||
get_decode_layer_kv(layer_id, num_blocks) -> (k, v)
|
|
||||||
end_decode_pipeline()
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 实施阶段
|
|
||||||
|
|
||||||
### Phase 1: 内存布局重构
|
|
||||||
- 修改 GPU cache 为 unified ring buffer
|
|
||||||
- 添加 per-layer prefill buffer
|
|
||||||
- 添加 cross-layer pipeline buffers
|
|
||||||
|
|
||||||
### Phase 2: API 实现
|
|
||||||
- 实现 ring buffer slot 管理 API
|
|
||||||
- 实现 per-layer prefill offload API
|
|
||||||
- 实现 cross-layer pipeline API
|
|
||||||
|
|
||||||
### Phase 3: 集成到 Attention Layer
|
|
||||||
- 修改 attention forward 流程
|
|
||||||
- 集成 per-layer prefill buffer
|
|
||||||
- 集成 cross-layer pipeline
|
|
||||||
|
|
||||||
### Phase 4: 集成到 Model Runner
|
|
||||||
- 实现 chunked prefill 流程
|
|
||||||
- 集成 LSE 合并
|
|
||||||
- 优化流水线
|
|
||||||
|
|
||||||
### Phase 5: Sparse Policy 集成(FULL)
|
|
||||||
- 设计统一的策略接口
|
|
||||||
- 实现 FullAttentionPolicy
|
|
||||||
- 预留 QUEST 等未来策略的扩展接口
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 关键决策
|
|
||||||
|
|
||||||
1. **Block-Based 设计优先**:支持任意长度推理的核心
|
|
||||||
2. **采用 tzj/minference 的内存布局**:GPU cache 无 layer 维度 + block-based
|
|
||||||
3. **以 tzj/vs_offload 为基础分支**:更好的文档和代码组织
|
|
||||||
4. **分阶段合并策略**:降低复杂度,便于验证
|
|
||||||
5. **Sparse Policy - FULL 优先**:保留架构,现阶段仅实现 FULL
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 预期结果
|
|
||||||
|
|
||||||
### 内存使用(28层模型,block_size=1024)
|
|
||||||
|
|
||||||
| 组件 | 内存 |
|
|
||||||
|------|------|
|
|
||||||
| GPU Unified Ring Buffer | ~4 MB |
|
|
||||||
| GPU Per-layer Prefill Buffer | ~58 MB |
|
|
||||||
| GPU Pipeline Buffers (×2) | ~1 GB |
|
|
||||||
| GPU Decode Buffer | ~58 MB |
|
|
||||||
| **GPU 总计** | **~1.6 GB (固定)** |
|
|
||||||
| CPU Cache (4M tokens) | ~23.3 GB |
|
|
||||||
| **总计 (4M tokens)** | **~24.9 GB** ✅ 适配 24GB RTX 3090 |
|
|
||||||
|
|
||||||
### 性能支持
|
|
||||||
|
|
||||||
- ✅ 支持 4M, 8M, 16M+ tokens 的推理
|
|
||||||
- ✅ GPU 内存固定,不随序列长度增长
|
|
||||||
- ✅ 完全并行的 layerwise offload
|
|
||||||
- ✅ Cross-layer 流水线优化
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 参考
|
|
||||||
|
|
||||||
- **OffloadEngine**: `nanovllm/kvcache/offload_engine.py`
|
|
||||||
- **Attention Layer**: `nanovllm/layers/attention.py`
|
|
||||||
- **Model Runner**: `nanovllm/engine/model_runner.py`
|
|
||||||
- **Sparse Policy**: `nanovllm/kvcache/sparse/policy.py`
|
|
||||||
@@ -1,196 +0,0 @@
|
|||||||
# CUDA Graph Support for CPU Offload Mode
|
|
||||||
|
|
||||||
This document describes the CUDA graph implementation for the CPU offload decode path, which provides significant performance improvements for decode throughput.
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
CUDA graphs capture a sequence of GPU operations and replay them with minimal CPU overhead. In offload mode, we capture per-layer graphs for the decode path, achieving **4x decode throughput improvement**.
|
|
||||||
|
|
||||||
## Performance Results
|
|
||||||
|
|
||||||
| Metric | Eager Mode | CUDA Graph | Improvement |
|
|
||||||
|--------|------------|------------|-------------|
|
|
||||||
| Decode Throughput | ~12 tok/s | ~50 tok/s | **4.2x** |
|
|
||||||
| TPOT (Time per output token) | ~80ms | ~19ms | **4.2x** |
|
|
||||||
| Prefill Throughput | ~8000 tok/s | ~8000 tok/s | Same |
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
### Why Standard CUDA Graph Capture Doesn't Work
|
|
||||||
|
|
||||||
The standard `capture_cudagraph()` captures the PagedAttention decode path:
|
|
||||||
- Uses block tables for scattered KV cache access
|
|
||||||
- `Attention.k_cache/v_cache` point to PagedAttention buffers
|
|
||||||
|
|
||||||
In offload mode, the decode path is different:
|
|
||||||
- Uses contiguous ring buffers for KV cache
|
|
||||||
- `Attention.k_cache/v_cache` dynamically point to ring buffer slices
|
|
||||||
- H2D transfers interleaved with compute
|
|
||||||
|
|
||||||
### Per-Layer Graph Design
|
|
||||||
|
|
||||||
We capture one CUDA graph per transformer layer:
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────────────────────────────────────────────────────┐
|
|
||||||
│ Offload Decode with CUDA Graphs │
|
|
||||||
├─────────────────────────────────────────────────────────────┤
|
|
||||||
│ │
|
|
||||||
│ Initialization: │
|
|
||||||
│ capture_offload_cudagraph() captures 36 layer graphs │
|
|
||||||
│ Each graph: layer.forward() with ring buffer as cache │
|
|
||||||
│ │
|
|
||||||
│ Decode Step: │
|
|
||||||
│ 1. Embedding (eager, outside graph) │
|
|
||||||
│ 2. For each layer: │
|
|
||||||
│ a. Wait for H2D load (outside graph) │
|
|
||||||
│ b. Copy decode KV to ring buffer (outside graph) │
|
|
||||||
│ c. Set Attention.k_cache = ring_buffer[buffer_idx] │
|
|
||||||
│ d. Set context (slot_mapping, context_lens) │
|
|
||||||
│ e. graph.replay() - layer forward │
|
|
||||||
│ f. synchronize() │
|
|
||||||
│ g. Copy layer_outputs -> hidden_states │
|
|
||||||
│ h. Copy new KV to decode buffer (outside graph) │
|
|
||||||
│ i. Start next layer H2D load │
|
|
||||||
│ 3. Final norm and logits (eager) │
|
|
||||||
│ │
|
|
||||||
└─────────────────────────────────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### Ring Buffer Mapping
|
|
||||||
|
|
||||||
Each layer maps to a ring buffer slot:
|
|
||||||
```python
|
|
||||||
buffer_idx = layer_id % num_kv_buffers
|
|
||||||
```
|
|
||||||
|
|
||||||
With 4 buffers and 36 layers:
|
|
||||||
- Layer 0, 4, 8, ... use buffer 0
|
|
||||||
- Layer 1, 5, 9, ... use buffer 1
|
|
||||||
- Layer 2, 6, 10, ... use buffer 2
|
|
||||||
- Layer 3, 7, 11, ... use buffer 3
|
|
||||||
|
|
||||||
## Implementation Details
|
|
||||||
|
|
||||||
### Graph Capture (`capture_offload_cudagraph`)
|
|
||||||
|
|
||||||
Location: `model_runner.py:1075-1164`
|
|
||||||
|
|
||||||
```python
|
|
||||||
def capture_offload_cudagraph(self):
|
|
||||||
# Fixed-address tensors for graph I/O
|
|
||||||
hidden_states = torch.randn(1, hidden_size, ...)
|
|
||||||
residual = torch.randn(1, hidden_size, ...)
|
|
||||||
layer_outputs = torch.zeros(1, hidden_size, ...)
|
|
||||||
layer_residual = torch.zeros(1, hidden_size, ...)
|
|
||||||
|
|
||||||
for layer_id in range(num_layers):
|
|
||||||
buffer_idx = layer_id % num_buffers
|
|
||||||
|
|
||||||
# Set Attention cache to ring buffer slice
|
|
||||||
attn_module.k_cache = ring_buffer[buffer_idx:buffer_idx+1]
|
|
||||||
attn_module.v_cache = ring_buffer[buffer_idx:buffer_idx+1]
|
|
||||||
|
|
||||||
# Set context for contiguous mode
|
|
||||||
set_context(is_prefill=False, slot_mapping=...,
|
|
||||||
context_lens=..., block_tables=None)
|
|
||||||
|
|
||||||
# Warmup and capture
|
|
||||||
with torch.cuda.graph(graph, pool):
|
|
||||||
out_h, out_r = layer(positions, hidden_states, residual)
|
|
||||||
layer_outputs.copy_(out_h)
|
|
||||||
layer_residual.copy_(out_r)
|
|
||||||
|
|
||||||
# Propagate state for next layer's capture
|
|
||||||
hidden_states.copy_(layer_outputs)
|
|
||||||
residual.copy_(layer_residual)
|
|
||||||
```
|
|
||||||
|
|
||||||
Key design decisions:
|
|
||||||
1. **Fixed-address tensors**: Graph inputs/outputs use pre-allocated tensors
|
|
||||||
2. **Include copy in graph**: `layer_outputs.copy_(out_h)` is captured
|
|
||||||
3. **State propagation**: Update hidden_states between layer captures
|
|
||||||
4. **Random initialization**: Use `randn` instead of zeros for realistic distributions
|
|
||||||
|
|
||||||
### Graph Replay (`run_layerwise_offload_decode`)
|
|
||||||
|
|
||||||
Location: `model_runner.py:844-1031`
|
|
||||||
|
|
||||||
```python
|
|
||||||
use_cuda_graph = not self.enforce_eager and hasattr(self, 'offload_graphs')
|
|
||||||
|
|
||||||
if use_cuda_graph:
|
|
||||||
# Use fixed-address tensors
|
|
||||||
graph_vars["positions"][0] = len(seq) - 1
|
|
||||||
graph_vars["slot_mapping"][0] = context_len
|
|
||||||
graph_vars["context_lens"][0] = context_len + 1
|
|
||||||
graph_vars["hidden_states"].copy_(embedding)
|
|
||||||
graph_vars["residual"].zero_()
|
|
||||||
|
|
||||||
for layer_id in range(num_layers):
|
|
||||||
# H2D and buffer setup (outside graph)
|
|
||||||
offload_engine.wait_buffer_load(current_buffer)
|
|
||||||
attn_module.k_cache = ring_buffer[current_buffer:current_buffer+1]
|
|
||||||
set_context(...)
|
|
||||||
|
|
||||||
if use_cuda_graph:
|
|
||||||
# Replay graph
|
|
||||||
self.offload_graphs[layer_id].replay()
|
|
||||||
torch.cuda.current_stream().synchronize()
|
|
||||||
|
|
||||||
# Copy outputs to inputs for next layer
|
|
||||||
if layer_id < num_layers - 1:
|
|
||||||
graph_vars["hidden_states"].copy_(graph_vars["layer_outputs"])
|
|
||||||
graph_vars["residual"].copy_(graph_vars["layer_residual"])
|
|
||||||
else:
|
|
||||||
# Eager execution
|
|
||||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
|
||||||
```
|
|
||||||
|
|
||||||
Key points:
|
|
||||||
1. **Synchronization required**: `synchronize()` after each graph replay
|
|
||||||
2. **Manual state propagation**: Copy layer_outputs to hidden_states between replays
|
|
||||||
3. **H2D outside graph**: Ring buffer loads happen before graph replay
|
|
||||||
|
|
||||||
## Limitations and Future Work
|
|
||||||
|
|
||||||
### Current Limitations
|
|
||||||
|
|
||||||
1. **Per-layer sync overhead**: Each layer requires synchronization
|
|
||||||
2. **No kernel fusion across layers**: Each layer is a separate graph
|
|
||||||
3. **Fixed batch size**: Only supports batch_size=1 for offload
|
|
||||||
|
|
||||||
### Future Optimization: Full-Decode Graph
|
|
||||||
|
|
||||||
Potential improvement: Capture entire decode step as single graph
|
|
||||||
- Complete all H2D loads before graph
|
|
||||||
- Single graph covers all 36 layers
|
|
||||||
- Better kernel fusion, less CPU overhead
|
|
||||||
- More complex to implement (handle buffer rotation inside graph)
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
Run needle test with CUDA graph:
|
|
||||||
```bash
|
|
||||||
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python tests/test_needle.py \
|
|
||||||
--input-len 32768 \
|
|
||||||
--enable-offload \
|
|
||||||
--use-cuda-graph
|
|
||||||
```
|
|
||||||
|
|
||||||
Run benchmark:
|
|
||||||
```bash
|
|
||||||
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python bench_offload.py \
|
|
||||||
--input-len 16384 \
|
|
||||||
--bench-all
|
|
||||||
```
|
|
||||||
|
|
||||||
## Files Modified
|
|
||||||
|
|
||||||
| File | Changes |
|
|
||||||
|------|---------|
|
|
||||||
| `model_runner.py:46-50` | Call `capture_offload_cudagraph()` for offload mode |
|
|
||||||
| `model_runner.py:69-73` | Clean up offload graph resources in `exit()` |
|
|
||||||
| `model_runner.py:844-1031` | Add CUDA graph support to `run_layerwise_offload_decode()` |
|
|
||||||
| `model_runner.py:1075-1164` | New `capture_offload_cudagraph()` method |
|
|
||||||
| `tests/test_needle.py` | Add `--use-cuda-graph` flag |
|
|
||||||
@@ -1,142 +0,0 @@
|
|||||||
# Debugging Guide
|
|
||||||
|
|
||||||
This document provides debugging techniques for nano-vLLM, including PyTorch hooks for capturing intermediate tensors.
|
|
||||||
|
|
||||||
## PyTorch Hooks for Debugging
|
|
||||||
|
|
||||||
### Hook Positions in Qwen3
|
|
||||||
|
|
||||||
```
|
|
||||||
decoder_layer
|
|
||||||
├── input_layernorm (RMSNorm)
|
|
||||||
├── self_attn (Qwen3Attention) ← Hook here for attention I/O after o_proj
|
|
||||||
│ ├── q_proj → q_norm → RoPE
|
|
||||||
│ ├── k_proj → k_norm → RoPE
|
|
||||||
│ ├── v_proj
|
|
||||||
│ ├── attn (Attention) ← Hook here for Q/K/V tensors
|
|
||||||
│ │ └── FlashAttention / SDPA
|
|
||||||
│ └── o_proj
|
|
||||||
├── post_attention_layernorm (RMSNorm)
|
|
||||||
└── mlp (Qwen3MLP)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Hook Types & Data Shapes
|
|
||||||
|
|
||||||
| Hook Position | Type | Captured Data |
|
|
||||||
|---------------|------|---------------|
|
|
||||||
| `self_attn` | post | `[batch, seq_len, hidden_size]` - after o_proj |
|
|
||||||
| `self_attn.attn` | pre | Q,K,V: `[seq_len, num_heads, head_dim]` - after RoPE |
|
|
||||||
| `self_attn.attn` | post | `[seq_len, num_heads, head_dim]` - before o_proj |
|
|
||||||
|
|
||||||
### Example: Capture Attention Outputs
|
|
||||||
|
|
||||||
```python
|
|
||||||
storage = {}
|
|
||||||
|
|
||||||
def make_hook(layer_id: int, storage: dict):
|
|
||||||
def hook(module, inputs, output):
|
|
||||||
if isinstance(output, tuple):
|
|
||||||
attn_output = output[0]
|
|
||||||
else:
|
|
||||||
attn_output = output
|
|
||||||
# nanovllm shape: [num_tokens, hidden_size] -> add batch dim
|
|
||||||
if attn_output.dim() == 2:
|
|
||||||
attn_output = attn_output.unsqueeze(0)
|
|
||||||
storage[layer_id] = attn_output.detach().clone()
|
|
||||||
return hook
|
|
||||||
|
|
||||||
# Register hooks
|
|
||||||
hooks = []
|
|
||||||
for layer_idx, layer in enumerate(model.model.layers):
|
|
||||||
hooks.append(layer.self_attn.register_forward_hook(make_hook(layer_idx, storage)))
|
|
||||||
|
|
||||||
# Run inference...
|
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
for hook in hooks:
|
|
||||||
hook.remove()
|
|
||||||
```
|
|
||||||
|
|
||||||
### Reference Implementation
|
|
||||||
|
|
||||||
Key files for comparison testing:
|
|
||||||
|
|
||||||
| File | Purpose |
|
|
||||||
|------|---------|
|
|
||||||
| `tests/modeling_qwen3.py` | Reference Qwen3 implementation (torch + transformers only) |
|
|
||||||
| `tests/test_needle_ref.py` | Reference needle test using custom Qwen3 |
|
|
||||||
| `tests/test_needle.py` | Needle-in-haystack test for nanovllm |
|
|
||||||
|
|
||||||
### Common Pitfalls
|
|
||||||
|
|
||||||
1. **Shape mismatch**: nanovllm uses `[num_tokens, ...]` while torch uses `[batch, seq_len, ...]`
|
|
||||||
2. **Hook position**: `self_attn` captures after o_proj, `self_attn.attn` captures before o_proj
|
|
||||||
3. **Output format**: nanovllm returns tuple `(attn_output, None)`, handle with `output[0]`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Memory Debugging
|
|
||||||
|
|
||||||
### Track Peak GPU Memory
|
|
||||||
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# Reset stats before operation
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# Run operation
|
|
||||||
outputs = llm.generate([prompt], sampling_params)
|
|
||||||
|
|
||||||
# Check peak
|
|
||||||
peak_gb = torch.cuda.max_memory_allocated() / 1024**3
|
|
||||||
print(f"Peak GPU memory: {peak_gb:.2f} GB")
|
|
||||||
```
|
|
||||||
|
|
||||||
### Monitor Memory During Execution
|
|
||||||
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
|
|
||||||
def memory_snapshot():
|
|
||||||
allocated = torch.cuda.memory_allocated() / 1024**3
|
|
||||||
reserved = torch.cuda.memory_reserved() / 1024**3
|
|
||||||
print(f"Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
|
|
||||||
|
|
||||||
# Add snapshots at key points in your code
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Comparing Outputs
|
|
||||||
|
|
||||||
### Needle-in-Haystack Test
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Test with CPU offload
|
|
||||||
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python tests/test_needle.py --enable-offload --input-len 8192
|
|
||||||
|
|
||||||
# Test without CPU offload (GPU-only)
|
|
||||||
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python tests/test_needle.py --input-len 8192
|
|
||||||
|
|
||||||
# Compare with reference implementation
|
|
||||||
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python tests/test_needle_ref.py --input-len 8192
|
|
||||||
```
|
|
||||||
|
|
||||||
### Tensor Comparison
|
|
||||||
|
|
||||||
```python
|
|
||||||
def compare_tensors(a, b, name, rtol=1e-3, atol=1e-5):
|
|
||||||
if a.shape != b.shape:
|
|
||||||
print(f"{name}: Shape mismatch {a.shape} vs {b.shape}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
diff = (a - b).abs()
|
|
||||||
max_diff = diff.max().item()
|
|
||||||
mean_diff = diff.mean().item()
|
|
||||||
|
|
||||||
close = torch.allclose(a, b, rtol=rtol, atol=atol)
|
|
||||||
print(f"{name}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}, close={close}")
|
|
||||||
return close
|
|
||||||
```
|
|
||||||
@@ -1,324 +0,0 @@
|
|||||||
# Notes: Sparsity Integration into Layerwise Offload
|
|
||||||
|
|
||||||
## Current Architecture Analysis
|
|
||||||
|
|
||||||
### GPU-Only Path vs Offload Path
|
|
||||||
|
|
||||||
| Aspect | GPU-Only | Layerwise Offload |
|
|
||||||
|--------|----------|-------------------|
|
|
||||||
| KV Storage | GPU blocks (paged) | CPU pinned + GPU ring buffer |
|
|
||||||
| Prefill | All layers → then attention | Per-layer: attention → offload |
|
|
||||||
| Decode | FlashAttn with block table | Ring buffer H2D → FlashAttn |
|
|
||||||
| Sparse Support | MInference via `attention.py` | Not integrated |
|
|
||||||
|
|
||||||
### MInference Flow (GPU-Only)
|
|
||||||
|
|
||||||
```
|
|
||||||
attention.py:101-105:
|
|
||||||
if context.sparse_prefill_policy is not None:
|
|
||||||
o = context.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
|
|
||||||
|
|
||||||
minference.py:sparse_prefill_attention():
|
|
||||||
1. estimate_pattern(q, k, layer_id) -> vertical_indices, slash_indices
|
|
||||||
2. _triton_mixed_sparse_attention(q, k, v, indices)
|
|
||||||
3. return output
|
|
||||||
```
|
|
||||||
|
|
||||||
### Quest Flow (GPU Block Mode)
|
|
||||||
|
|
||||||
```
|
|
||||||
hybrid_manager.py (if using CPU offload with Quest):
|
|
||||||
select_blocks(available_blocks, ctx) -> selected block IDs
|
|
||||||
-> load selected blocks to GPU
|
|
||||||
-> standard FlashAttn with loaded blocks
|
|
||||||
```
|
|
||||||
|
|
||||||
### Layerwise Offload Prefill Flow
|
|
||||||
|
|
||||||
```
|
|
||||||
model_runner.py:run_layerwise_offload_prefill():
|
|
||||||
for layer_id in range(num_layers):
|
|
||||||
# QKV projection
|
|
||||||
q, k, v = qkv_proj(hidden_ln)
|
|
||||||
|
|
||||||
# RoPE
|
|
||||||
q, k = rotary_emb(positions, q, k)
|
|
||||||
|
|
||||||
# FULL attention (no sparsity!)
|
|
||||||
attn_output = flash_attn_varlen_func(q, k, v, ...)
|
|
||||||
|
|
||||||
# MLP
|
|
||||||
hidden_states = mlp(attn_out + residual)
|
|
||||||
|
|
||||||
# Sync offload ALL k, v to CPU
|
|
||||||
for block_id in cpu_block_ids:
|
|
||||||
k_cache_cpu[layer_id, block_id].copy_(k[start:end])
|
|
||||||
v_cache_cpu[layer_id, block_id].copy_(v[start:end])
|
|
||||||
```
|
|
||||||
|
|
||||||
### Layerwise Offload Decode Flow
|
|
||||||
|
|
||||||
```
|
|
||||||
model_runner.py:run_layerwise_offload_decode():
|
|
||||||
# Preload first N layers to ring buffer
|
|
||||||
for i in range(num_buffers):
|
|
||||||
offload_engine.load_layer_kv_to_buffer(i, i, cpu_block_table, valid_tokens)
|
|
||||||
|
|
||||||
for layer_id in range(num_layers):
|
|
||||||
current_buffer = layer_id % num_buffers
|
|
||||||
|
|
||||||
# Wait for buffer load
|
|
||||||
offload_engine.wait_buffer_load(current_buffer)
|
|
||||||
|
|
||||||
# Get prefilled KV from ring buffer (ALL blocks loaded)
|
|
||||||
k_prefill, v_prefill = offload_engine.get_buffer_kv(current_buffer, total_prefill_tokens)
|
|
||||||
|
|
||||||
# QKV for new token
|
|
||||||
q, k_new, v_new = qkv_proj(hidden_ln)
|
|
||||||
|
|
||||||
# Concat and full attention
|
|
||||||
k_full = torch.cat([k_prefill, k_decode_prev, k_new])
|
|
||||||
attn_output = flash_attn_varlen_func(q, k_full, v_full, ...)
|
|
||||||
|
|
||||||
# Start loading next layer
|
|
||||||
offload_engine.load_layer_kv_to_buffer(current_buffer, layer_id + num_buffers, ...)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Integration Points
|
|
||||||
|
|
||||||
### 1. Prefill Sparse Integration Point
|
|
||||||
|
|
||||||
**Location:** `model_runner.py:535-543`
|
|
||||||
|
|
||||||
**Current:**
|
|
||||||
```python
|
|
||||||
attn_output = flash_attn_varlen_func(
|
|
||||||
q, k, v,
|
|
||||||
cu_seqlens_q=cu_seqlens,
|
|
||||||
cu_seqlens_k=cu_seqlens,
|
|
||||||
max_seqlen_q=total_tokens,
|
|
||||||
max_seqlen_k=total_tokens,
|
|
||||||
softmax_scale=layer.self_attn.attn.scale,
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**After Integration:**
|
|
||||||
```python
|
|
||||||
if self.sparse_policy and self.sparse_policy.supports_offload_prefill:
|
|
||||||
attn_output, k_sparse, v_sparse = self.sparse_policy.offload_prefill_attention(
|
|
||||||
q, k, v, layer_id
|
|
||||||
)
|
|
||||||
k_to_offload = k_sparse if k_sparse is not None else k
|
|
||||||
v_to_offload = v_sparse if v_sparse is not None else v
|
|
||||||
else:
|
|
||||||
attn_output = flash_attn_varlen_func(q, k, v, ...)
|
|
||||||
k_to_offload, v_to_offload = k, v
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Decode Sparse Integration Point
|
|
||||||
|
|
||||||
**Location:** `model_runner.py:636-637` and `model_runner.py:704-706`
|
|
||||||
|
|
||||||
**Current (preload):**
|
|
||||||
```python
|
|
||||||
for i in range(num_preload):
|
|
||||||
offload_engine.load_layer_kv_to_buffer(
|
|
||||||
i, i, cpu_block_table, valid_tokens_per_block
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**After Integration:**
|
|
||||||
```python
|
|
||||||
for i in range(num_preload):
|
|
||||||
layer_to_load = i
|
|
||||||
if self.sparse_policy and self.sparse_policy.supports_offload_decode:
|
|
||||||
# Prepare q for this layer (need to compute ahead)
|
|
||||||
# OR: use previous layer's pattern as estimate
|
|
||||||
selected_blocks = self.sparse_policy.select_offload_blocks(
|
|
||||||
None, # q not available yet at preload
|
|
||||||
layer_to_load,
|
|
||||||
cpu_block_table,
|
|
||||||
valid_tokens_per_block
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
selected_blocks = cpu_block_table
|
|
||||||
offload_engine.load_sparse_layer_kv_to_buffer(
|
|
||||||
i, layer_to_load, selected_blocks, valid_tokens_per_block
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Challenge:** Q is not available during preload phase!
|
|
||||||
|
|
||||||
**Solutions:**
|
|
||||||
1. Skip sparse preload, only sparse for non-preloaded layers
|
|
||||||
2. Use previous decode step's pattern as estimate
|
|
||||||
3. Add preload hook to sparse policy
|
|
||||||
|
|
||||||
### 3. Offload Engine Extension
|
|
||||||
|
|
||||||
**New Method in OffloadEngine:**
|
|
||||||
|
|
||||||
```python
|
|
||||||
def load_sparse_layer_kv_to_buffer(
|
|
||||||
self,
|
|
||||||
buffer_idx: int,
|
|
||||||
layer_id: int,
|
|
||||||
selected_cpu_block_ids: List[int],
|
|
||||||
original_valid_tokens: List[int],
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Load only selected blocks from CPU to buffer.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Total tokens loaded (may be less than full sequence)
|
|
||||||
"""
|
|
||||||
stream = self.layer_load_streams[buffer_idx]
|
|
||||||
|
|
||||||
with torch.cuda.stream(stream):
|
|
||||||
stream.wait_event(self.buffer_compute_done_events[buffer_idx])
|
|
||||||
|
|
||||||
# Build mapping: original block -> selected position
|
|
||||||
offset = 0
|
|
||||||
for i, cpu_block_id in enumerate(selected_cpu_block_ids):
|
|
||||||
# Find original index to get valid tokens
|
|
||||||
valid_tokens = original_valid_tokens[i] # Need mapping
|
|
||||||
|
|
||||||
self.layer_k_cache[buffer_idx, offset:offset+valid_tokens].copy_(
|
|
||||||
self.k_cache_cpu[layer_id, cpu_block_id, :valid_tokens],
|
|
||||||
non_blocking=True
|
|
||||||
)
|
|
||||||
# ... v_cache same
|
|
||||||
|
|
||||||
offset += valid_tokens
|
|
||||||
|
|
||||||
self.buffer_load_events[buffer_idx].record(stream)
|
|
||||||
|
|
||||||
return offset # Caller needs to know actual loaded tokens
|
|
||||||
```
|
|
||||||
|
|
||||||
## Metadata Flow for Quest
|
|
||||||
|
|
||||||
### During Prefill Offload
|
|
||||||
|
|
||||||
**Current:** No metadata collection in offload path
|
|
||||||
|
|
||||||
**Required:** Call `on_prefill_offload()` for each block
|
|
||||||
|
|
||||||
```python
|
|
||||||
# In run_layerwise_offload_prefill()
|
|
||||||
for i, cpu_block_id in enumerate(cpu_block_ids):
|
|
||||||
start = i * block_size
|
|
||||||
end = min(start + block_size, total_tokens)
|
|
||||||
actual_size = end - start
|
|
||||||
|
|
||||||
# BEFORE offload: update Quest metadata
|
|
||||||
if self.sparse_policy and hasattr(self.sparse_policy, 'on_prefill_offload'):
|
|
||||||
self.sparse_policy.on_prefill_offload(
|
|
||||||
cpu_block_id, layer_id, k[start:end], actual_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# Offload
|
|
||||||
offload_engine.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
|
|
||||||
offload_engine.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
|
|
||||||
```
|
|
||||||
|
|
||||||
### Quest Metadata Shape
|
|
||||||
|
|
||||||
```python
|
|
||||||
# BlockMetadataManager
|
|
||||||
key_min: [num_blocks, num_layers, num_kv_heads, head_dim] # Min key per block per layer
|
|
||||||
key_max: [num_blocks, num_layers, num_kv_heads, head_dim] # Max key per block per layer
|
|
||||||
```
|
|
||||||
|
|
||||||
**Memory:** 2 * num_blocks * num_layers * kv_heads * head_dim * 2 bytes
|
|
||||||
- Example: 1000 blocks * 28 layers * 4 heads * 128 dim * 2 * 2 = ~57 MB
|
|
||||||
|
|
||||||
## Performance Considerations
|
|
||||||
|
|
||||||
### MInference Prefill Overhead
|
|
||||||
|
|
||||||
| Operation | Time (64K seq) |
|
|
||||||
|-----------|----------------|
|
|
||||||
| Pattern estimation (last-64) | ~5ms |
|
|
||||||
| Triton sparse attention | ~80ms |
|
|
||||||
| Full FlashAttention | ~100ms |
|
|
||||||
| **Net Speedup** | ~15-20% |
|
|
||||||
|
|
||||||
### Quest Decode Overhead
|
|
||||||
|
|
||||||
| Operation | Time |
|
|
||||||
|-----------|------|
|
|
||||||
| Block scoring (GPU metadata) | ~0.1ms |
|
|
||||||
| Top-K selection | ~0.05ms |
|
|
||||||
| Sparse H2D load (8 blocks) | ~2ms |
|
|
||||||
| Full H2D load (100 blocks) | ~20ms |
|
|
||||||
| **Net Speedup** | ~10x H2D |
|
|
||||||
|
|
||||||
### Memory Trade-offs
|
|
||||||
|
|
||||||
| Mode | GPU Memory | CPU Memory | H2D Bandwidth |
|
|
||||||
|------|------------|------------|---------------|
|
|
||||||
| Full offload | Ring buffer | Full KV | High |
|
|
||||||
| Sparse offload | Ring buffer | Full KV | Low (subset) |
|
|
||||||
| Aggressive sparse | Ring buffer | Sparse KV | Very low |
|
|
||||||
|
|
||||||
## Edge Cases
|
|
||||||
|
|
||||||
### 1. Short Sequences (< sparse threshold)
|
|
||||||
|
|
||||||
```python
|
|
||||||
if total_tokens < sparse_threshold:
|
|
||||||
# Fall back to full attention
|
|
||||||
use_sparse = False
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. First Decode Step (no previous Q)
|
|
||||||
|
|
||||||
Quest can't score blocks without Q. Options:
|
|
||||||
- Use average embedding as proxy
|
|
||||||
- Load all blocks for first step
|
|
||||||
- Use prefill pattern as estimate
|
|
||||||
|
|
||||||
### 3. Variable Sequence Lengths in Batch
|
|
||||||
|
|
||||||
Layerwise offload currently only supports batch_size=1:
|
|
||||||
```python
|
|
||||||
assert len(seqs) == 1, "Layer-wise offload only supports single sequence"
|
|
||||||
```
|
|
||||||
|
|
||||||
Sparse integration should maintain this constraint.
|
|
||||||
|
|
||||||
### 4. Ring Buffer vs Sparse Load Mismatch
|
|
||||||
|
|
||||||
Ring buffer assumes fixed `total_prefill_tokens`:
|
|
||||||
```python
|
|
||||||
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, total_prefill_tokens)
|
|
||||||
```
|
|
||||||
|
|
||||||
Sparse load has variable token count. Need:
|
|
||||||
```python
|
|
||||||
# Track actual loaded tokens per buffer
|
|
||||||
loaded_tokens[buffer_idx] = sparse_load_count
|
|
||||||
k_prefill, v_prefill = offload_engine.get_buffer_kv(buffer_idx, loaded_tokens[buffer_idx])
|
|
||||||
```
|
|
||||||
|
|
||||||
## Testing Strategy
|
|
||||||
|
|
||||||
### Unit Tests
|
|
||||||
|
|
||||||
1. `test_sparse_policy_interface.py` - Verify new interface methods
|
|
||||||
2. `test_minference_offload.py` - MInference in offload mode
|
|
||||||
3. `test_quest_offload.py` - Quest block selection in offload mode
|
|
||||||
|
|
||||||
### Integration Tests
|
|
||||||
|
|
||||||
1. `test_offload_sparse_e2e.py` - Full prefill+decode with sparsity
|
|
||||||
2. `test_accuracy_comparison.py` - Compare outputs: full vs sparse
|
|
||||||
|
|
||||||
### Benchmarks
|
|
||||||
|
|
||||||
1. `bench_offload_sparse.py` - Compare:
|
|
||||||
- Full offload (baseline)
|
|
||||||
- MInference prefill + Quest decode
|
|
||||||
- Aggressive sparse offload
|
|
||||||
@@ -1,194 +0,0 @@
|
|||||||
# GPU-only Performance Issue: PagedAttention Scatter Overhead
|
|
||||||
|
|
||||||
## Problem Summary
|
|
||||||
|
|
||||||
GPU-only mode with MInference is **slower** than CPU offload mode for long-context single-sequence inference:
|
|
||||||
|
|
||||||
| Mode | Prefill Speed (32K tokens, Qwen3-4B) |
|
|
||||||
|------|--------------------------------------|
|
|
||||||
| GPU-only + MInference | 3383 tok/s |
|
|
||||||
| Offload + MInference | 5373 tok/s |
|
|
||||||
|
|
||||||
This counterintuitive result is caused by **unnecessary `store_kvcache` overhead** in the GPU-only path.
|
|
||||||
|
|
||||||
## Root Cause Analysis
|
|
||||||
|
|
||||||
### GPU-only Execution Path
|
|
||||||
|
|
||||||
```python
|
|
||||||
# attention.py line 86-110
|
|
||||||
def forward(self, q, k, v):
|
|
||||||
# ALWAYS store to cache first - OVERHEAD HERE
|
|
||||||
if k_cache.numel() and v_cache.numel():
|
|
||||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping) # ← Always executed
|
|
||||||
|
|
||||||
if context.is_prefill:
|
|
||||||
if context.sparse_prefill_policy is not None:
|
|
||||||
# MInference: uses k, v directly, NOT k_cache!
|
|
||||||
o = sparse_prefill_attention(q, k, v, layer_id)
|
|
||||||
else:
|
|
||||||
# Full attention: also uses k, v directly
|
|
||||||
o = flash_attn_varlen_func(q, k, v, ...)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Key observation**: Prefill attention **never reads from cache** - it uses the computed k, v directly. But `store_kvcache` is always called before attention.
|
|
||||||
|
|
||||||
### The `store_kvcache` Overhead
|
|
||||||
|
|
||||||
```python
|
|
||||||
# attention.py line 8-59
|
|
||||||
def store_kvcache(key, value, k_cache, v_cache, slot_mapping):
|
|
||||||
# 1. Filter invalid slots (conditional logic)
|
|
||||||
valid_mask = slot_mapping >= 0
|
|
||||||
valid_slots = slot_mapping[valid_mask]
|
|
||||||
valid_keys = key[valid_mask]
|
|
||||||
|
|
||||||
# 2. Reshape for scatter operation
|
|
||||||
k_cache_flat = k_cache.view(total_slots, D)
|
|
||||||
valid_keys_flat = valid_keys.reshape(-1, D)
|
|
||||||
|
|
||||||
# 3. Scatter write via index_copy_ - EXPENSIVE!
|
|
||||||
k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat)
|
|
||||||
v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat)
|
|
||||||
```
|
|
||||||
|
|
||||||
This scatter operation is called for **every layer** (28 layers for Qwen3-4B), writing **all tokens** (32K) to GPU cache.
|
|
||||||
|
|
||||||
### Offload Path (No Such Overhead)
|
|
||||||
|
|
||||||
```python
|
|
||||||
# model_runner.py - run_layerwise_offload_prefill
|
|
||||||
for layer_id in range(num_layers):
|
|
||||||
# QKV projection + RoPE
|
|
||||||
q, k = layer.self_attn.rotary_emb(positions, q, k)
|
|
||||||
|
|
||||||
# Sparse attention - directly uses k, v
|
|
||||||
attn_output = sparse_prefill_attention(q, k, v, layer_id)
|
|
||||||
|
|
||||||
# Contiguous copy to CPU - no scatter!
|
|
||||||
offload_engine.offload_layer_kv_sync(layer_id, k, v, cpu_block_ids, total_tokens)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Memory Layout Comparison
|
|
||||||
|
|
||||||
| Aspect | GPU-only (PagedAttention) | Offload (Contiguous) |
|
|
||||||
|--------|---------------------------|----------------------|
|
|
||||||
| **Layout** | `[num_blocks, block_size, heads, dim]` | `[seq_len, heads, dim]` |
|
|
||||||
| **Write pattern** | Scatter via `index_copy_` | Contiguous `copy_()` |
|
|
||||||
| **Indirection** | slot_mapping lookup | None |
|
|
||||||
| **Memory efficiency** | High (shared block pool) | Low (reserved per seq) |
|
|
||||||
| **Write performance** | Slow (memory-bound scatter) | Fast (simple DMA) |
|
|
||||||
|
|
||||||
### Why PagedAttention Uses Scatter
|
|
||||||
|
|
||||||
PagedAttention is designed for:
|
|
||||||
1. **Multi-sequence batching**: Different sequences share a block pool
|
|
||||||
2. **Dynamic memory management**: No need to reserve max_len per sequence
|
|
||||||
3. **Prefix caching**: Shared KV blocks across sequences
|
|
||||||
|
|
||||||
But for **single-sequence long-context** inference, these benefits don't apply, and we only pay the scatter overhead.
|
|
||||||
|
|
||||||
## Why `store_kvcache` is Still Needed
|
|
||||||
|
|
||||||
Even though prefill attention doesn't read from cache, **decode** does:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# attention.py line 111-114
|
|
||||||
else: # decode
|
|
||||||
# Reads from cache!
|
|
||||||
o = flash_attn_with_kvcache(q, k_cache, v_cache, block_table=...)
|
|
||||||
```
|
|
||||||
|
|
||||||
So `store_kvcache` during prefill is preparing KV cache for future decode steps.
|
|
||||||
|
|
||||||
## Potential Optimizations
|
|
||||||
|
|
||||||
### Option 1: Async Store After Attention (Low Effort)
|
|
||||||
|
|
||||||
Move `store_kvcache` after attention computation and make it async:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def forward(self, q, k, v):
|
|
||||||
if context.is_prefill:
|
|
||||||
# Compute attention first
|
|
||||||
if context.sparse_prefill_policy is not None:
|
|
||||||
o = sparse_prefill_attention(q, k, v, layer_id)
|
|
||||||
else:
|
|
||||||
o = flash_attn_varlen_func(q, k, v, ...)
|
|
||||||
|
|
||||||
# Then store async (overlaps with next layer's QKV)
|
|
||||||
if k_cache.numel():
|
|
||||||
store_kvcache_async(k, v, k_cache, v_cache, slot_mapping)
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
**Expected benefit**: Overlap store with compute, ~20-30% improvement.
|
|
||||||
|
|
||||||
### Option 2: Contiguous Layout for Single-Sequence Mode (Medium Effort)
|
|
||||||
|
|
||||||
Add a "contiguous mode" for single-sequence long-context:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class ContiguousKVCache:
|
|
||||||
"""Simple contiguous KV cache for single-sequence mode."""
|
|
||||||
def __init__(self, num_layers, max_seq_len, num_kv_heads, head_dim, dtype):
|
|
||||||
self.k_cache = torch.zeros(num_layers, max_seq_len, num_kv_heads, head_dim, dtype=dtype)
|
|
||||||
self.v_cache = torch.zeros(num_layers, max_seq_len, num_kv_heads, head_dim, dtype=dtype)
|
|
||||||
|
|
||||||
def store(self, layer_id, k, v, start_pos):
|
|
||||||
# Simple contiguous write - no scatter!
|
|
||||||
seq_len = k.shape[0]
|
|
||||||
self.k_cache[layer_id, start_pos:start_pos+seq_len] = k
|
|
||||||
self.v_cache[layer_id, start_pos:start_pos+seq_len] = v
|
|
||||||
```
|
|
||||||
|
|
||||||
**Expected benefit**: Match or exceed offload performance (~60% improvement).
|
|
||||||
|
|
||||||
### Option 3: Fused Store-Attention Kernel (High Effort)
|
|
||||||
|
|
||||||
Create a fused Triton kernel that:
|
|
||||||
1. Computes QKV projection
|
|
||||||
2. Stores K, V to cache
|
|
||||||
3. Computes attention
|
|
||||||
|
|
||||||
This eliminates memory roundtrips entirely.
|
|
||||||
|
|
||||||
**Expected benefit**: Best possible performance, but high implementation complexity.
|
|
||||||
|
|
||||||
## Recommended Action
|
|
||||||
|
|
||||||
For **single-sequence long-context** workloads (the primary use case for MInference):
|
|
||||||
|
|
||||||
1. **Short term**: Use offload mode - it's actually faster!
|
|
||||||
2. **Medium term**: Implement Option 1 (async store) for quick win
|
|
||||||
3. **Long term**: Consider Option 2 (contiguous layout) for GPU-only mode
|
|
||||||
|
|
||||||
## Performance Measurement
|
|
||||||
|
|
||||||
To reproduce the benchmark:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# GPU-only + MInference
|
|
||||||
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python tests/test_needle.py \
|
|
||||||
--model ~/models/Qwen3-4B-Instruct-2507/ \
|
|
||||||
--input-len 32768 \
|
|
||||||
--enable-minference
|
|
||||||
|
|
||||||
# Offload + MInference
|
|
||||||
PYTHONPATH=/path/to/nano-vllm:$PYTHONPATH python tests/test_needle.py \
|
|
||||||
--model ~/models/Qwen3-4B-Instruct-2507/ \
|
|
||||||
--input-len 32768 \
|
|
||||||
--enable-offload \
|
|
||||||
--enable-minference
|
|
||||||
```
|
|
||||||
|
|
||||||
## Related Files
|
|
||||||
|
|
||||||
- `nanovllm/layers/attention.py`: `store_kvcache()` and `Attention.forward()`
|
|
||||||
- `nanovllm/engine/model_runner.py`: `run_layerwise_offload_prefill()`
|
|
||||||
- `nanovllm/kvcache/offload_engine.py`: `offload_layer_kv_sync()`
|
|
||||||
|
|
||||||
## References
|
|
||||||
|
|
||||||
- [PagedAttention Paper](https://arxiv.org/abs/2309.06180) - vLLM's memory management
|
|
||||||
- [MInference Paper](https://arxiv.org/abs/2407.02490) - Sparse prefill attention
|
|
||||||
@@ -1,547 +0,0 @@
|
|||||||
# Layer-wise Offload Memory Analysis
|
|
||||||
|
|
||||||
This document provides a detailed analysis of memory allocations in the layer-wise CPU offload system, distinguishing between pre-allocated (managed) memory and temporary (non-pre-allocated) memory.
|
|
||||||
|
|
||||||
## Variable Notation
|
|
||||||
|
|
||||||
| Symbol | Description | Example (Qwen3-4B) |
|
|
||||||
|--------|-------------|-------------------|
|
|
||||||
| `seq_len` | Input sequence length | 131072 (128k) |
|
|
||||||
| `hidden_size` | Model hidden dimension | 2560 |
|
|
||||||
| `num_heads` | Number of attention heads | 20 |
|
|
||||||
| `num_kv_heads` | Number of KV heads (GQA) | 8 |
|
|
||||||
| `head_dim` | Dimension per head | 128 |
|
|
||||||
| `intermediate_size` | MLP intermediate dimension | 13696 |
|
|
||||||
| `num_layers` | Number of transformer layers | 36 |
|
|
||||||
| `block_size` | KV cache block size | 1024 |
|
|
||||||
| `num_kv_buffers` | Ring buffer count | 4 |
|
|
||||||
| `num_cpu_blocks` | Number of CPU cache blocks | 128 |
|
|
||||||
| `vocab_size` | Vocabulary size | 151936 |
|
|
||||||
| `dtype_size` | Bytes per element (fp16/bf16) | 2 |
|
|
||||||
|
|
||||||
Derived values:
|
|
||||||
- `kv_dim = num_kv_heads × head_dim`
|
|
||||||
- `q_size = num_heads × head_dim`
|
|
||||||
- `kv_size = num_kv_heads × head_dim`
|
|
||||||
- `qkv_size = q_size + 2 × kv_size`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 1. Pre-allocated Memory (Managed by nanovllm)
|
|
||||||
|
|
||||||
These tensors are allocated once during initialization and reused throughout inference.
|
|
||||||
|
|
||||||
### 1.1 OffloadEngine Managed Memory
|
|
||||||
|
|
||||||
| Tensor | Shape | Size Formula | Location |
|
|
||||||
|--------|-------|--------------|----------|
|
|
||||||
| `layer_k_cache` | `[num_kv_buffers, seq_len, num_kv_heads, head_dim]` | `num_kv_buffers × seq_len × kv_dim × dtype_size` | GPU |
|
|
||||||
| `layer_v_cache` | `[num_kv_buffers, seq_len, num_kv_heads, head_dim]` | `num_kv_buffers × seq_len × kv_dim × dtype_size` | GPU |
|
|
||||||
| `decode_k_buffer` | `[num_layers, block_size, num_kv_heads, head_dim]` | `num_layers × block_size × kv_dim × dtype_size` | GPU |
|
|
||||||
| `decode_v_buffer` | `[num_layers, block_size, num_kv_heads, head_dim]` | `num_layers × block_size × kv_dim × dtype_size` | GPU |
|
|
||||||
| `k_cache_cpu` | `[num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim]` | `num_layers × num_cpu_blocks × block_size × kv_dim × dtype_size` | CPU (pinned) |
|
|
||||||
| `v_cache_cpu` | `[num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim]` | `num_layers × num_cpu_blocks × block_size × kv_dim × dtype_size` | CPU (pinned) |
|
|
||||||
|
|
||||||
**Total GPU (OffloadEngine)**: `2 × (num_kv_buffers × seq_len + num_layers × block_size) × kv_dim × dtype_size`
|
|
||||||
|
|
||||||
**Total CPU (OffloadEngine)**: `2 × num_layers × num_cpu_blocks × block_size × kv_dim × dtype_size`
|
|
||||||
|
|
||||||
### 1.2 Model Weights
|
|
||||||
|
|
||||||
| Component | Approximate Size |
|
|
||||||
|-----------|-----------------|
|
|
||||||
| Embedding | `vocab_size × hidden_size × dtype_size` |
|
|
||||||
| Per-layer QKV proj | `hidden_size × qkv_size × dtype_size` |
|
|
||||||
| Per-layer O proj | `q_size × hidden_size × dtype_size` |
|
|
||||||
| Per-layer MLP | `hidden_size × 2 × intermediate_size × dtype_size + intermediate_size × hidden_size × dtype_size` |
|
|
||||||
| Per-layer LayerNorm | `2 × hidden_size × dtype_size` |
|
|
||||||
| LM Head | `hidden_size × vocab_size × dtype_size` |
|
|
||||||
|
|
||||||
### 1.3 RoPE Cache
|
|
||||||
|
|
||||||
| Tensor | Shape | Size |
|
|
||||||
|--------|-------|------|
|
|
||||||
| `cos_sin_cache` | `[max_position, 1, head_dim]` | `max_position × head_dim × 4` (float32) |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 2. Non-Pre-allocated Memory: Prefill Phase
|
|
||||||
|
|
||||||
Location: `model_runner.py:run_layerwise_offload_prefill()`
|
|
||||||
|
|
||||||
### 2.1 Persistent Tensors (Live Throughout Prefill)
|
|
||||||
|
|
||||||
| Variable | Line | Shape | Size | Notes |
|
|
||||||
|----------|------|-------|------|-------|
|
|
||||||
| `input_ids` | 488 | `[seq_len]` | `seq_len × 8` | int64 |
|
|
||||||
| `positions` | 489 | `[seq_len]` | `seq_len × 8` | int64 |
|
|
||||||
| `cu_seqlens` | 493 | `[2]` | negligible | int32 |
|
|
||||||
| `hidden_states` | 497 | `[seq_len, hidden_size]` | `seq_len × hidden_size × dtype_size` | Embedding output |
|
|
||||||
| `residual` | 506 | `[seq_len, hidden_size]` | `seq_len × hidden_size × dtype_size` | Residual connection |
|
|
||||||
|
|
||||||
### 2.2 Per-Layer Temporary Tensors
|
|
||||||
|
|
||||||
These are allocated and deallocated within each layer iteration.
|
|
||||||
|
|
||||||
#### 2.2.1 LayerNorm
|
|
||||||
|
|
||||||
| Variable | Line | Shape | Size | Notes |
|
|
||||||
|----------|------|-------|------|-------|
|
|
||||||
| `hidden_ln` | 506-508 | `[seq_len, hidden_size]` | `seq_len × hidden_size × dtype_size` | Input layernorm output |
|
|
||||||
|
|
||||||
**Inside RMSNorm** (`layernorm.py:add_rms_forward`):
|
|
||||||
| Variable | Shape | Size | Notes |
|
|
||||||
|----------|-------|------|-------|
|
|
||||||
| `x.float()` | `[seq_len, hidden_size]` | `seq_len × hidden_size × 4` | Upcasted to float32 |
|
|
||||||
| `var` | `[seq_len, 1]` | `seq_len × 4` | Variance |
|
|
||||||
|
|
||||||
#### 2.2.2 QKV Projection
|
|
||||||
|
|
||||||
| Variable | Line | Shape | Size | Notes |
|
|
||||||
|----------|------|-------|------|-------|
|
|
||||||
| `qkv` | 512 | `[seq_len, q_size + 2 × kv_size]` | `seq_len × qkv_size × dtype_size` | Merged QKV output |
|
|
||||||
| `q` | 513-519 | `[seq_len, num_heads, head_dim]` | 0 (view) | View of qkv |
|
|
||||||
| `k` | 513-520 | `[seq_len, num_kv_heads, head_dim]` | 0 (view) | View of qkv |
|
|
||||||
| `v` | 513-521 | `[seq_len, num_kv_heads, head_dim]` | 0 (view) | View of qkv |
|
|
||||||
|
|
||||||
#### 2.2.3 Q/K Norms (Qwen3 specific)
|
|
||||||
|
|
||||||
| Variable | Line | Shape | Size | Notes |
|
|
||||||
|----------|------|-------|------|-------|
|
|
||||||
| `q.reshape()` | 526 | `[seq_len × num_heads, head_dim]` | 0 (view) | Reshape for norm |
|
|
||||||
| `k.reshape()` | 528 | `[seq_len × num_kv_heads, head_dim]` | 0 (view) | Reshape for norm |
|
|
||||||
| RMSNorm intermediates | - | see above | `seq_len × num_heads × head_dim × 4` | Float32 upcasting |
|
|
||||||
|
|
||||||
#### 2.2.4 RoPE (Rotary Position Embedding)
|
|
||||||
|
|
||||||
Location: `rotary_embedding.py:apply_rotary_emb()`
|
|
||||||
|
|
||||||
| Variable | Line | Shape | Size | Notes |
|
|
||||||
|----------|------|-------|------|-------|
|
|
||||||
| `cos_sin` | 44 | `[seq_len, 1, head_dim]` | 0 (view) | View of cached cos_sin |
|
|
||||||
| `cos` | 45 | `[seq_len, 1, head_dim/2]` | 0 (view) | Chunk view |
|
|
||||||
| `sin` | 45 | `[seq_len, 1, head_dim/2]` | 0 (view) | Chunk view |
|
|
||||||
|
|
||||||
**Inside `apply_rotary_emb` for Q** (`rotary_embedding.py:6-14`):
|
|
||||||
| Variable | Shape | Size | Notes |
|
|
||||||
|----------|-------|------|-------|
|
|
||||||
| `x.float()` | `[seq_len, num_heads, head_dim]` | `seq_len × num_heads × head_dim × 4` | Upcast to float32 |
|
|
||||||
| `x1` | `[seq_len, num_heads, head_dim/2]` | 0 (view) | Chunk view |
|
|
||||||
| `x2` | `[seq_len, num_heads, head_dim/2]` | 0 (view) | Chunk view |
|
|
||||||
| `y1 = x1*cos - x2*sin` | `[seq_len, num_heads, head_dim/2]` | `seq_len × num_heads × head_dim/2 × 4` | New tensor |
|
|
||||||
| `y2 = x2*cos + x1*sin` | `[seq_len, num_heads, head_dim/2]` | `seq_len × num_heads × head_dim/2 × 4` | New tensor |
|
|
||||||
| `torch.cat((y1, y2))` | `[seq_len, num_heads, head_dim]` | `seq_len × num_heads × head_dim × 4` | New tensor |
|
|
||||||
| `.to(x.dtype)` | `[seq_len, num_heads, head_dim]` | `seq_len × num_heads × head_dim × dtype_size` | Downcast |
|
|
||||||
|
|
||||||
**Inside `apply_rotary_emb` for K**:
|
|
||||||
| Variable | Shape | Size | Notes |
|
|
||||||
|----------|-------|------|-------|
|
|
||||||
| Same pattern as Q | `[seq_len, num_kv_heads, head_dim]` | Similar, with `num_kv_heads` | |
|
|
||||||
|
|
||||||
**Total RoPE temporary for Q+K**: ~`seq_len × (num_heads + num_kv_heads) × head_dim × 4 × 3` (float32 intermediates)
|
|
||||||
|
|
||||||
#### 2.2.5 FlashAttention
|
|
||||||
|
|
||||||
| Variable | Line | Shape | Size | Notes |
|
|
||||||
|----------|------|-------|------|-------|
|
|
||||||
| `attn_output` | 535 | `[seq_len, num_heads, head_dim]` | `seq_len × num_heads × head_dim × dtype_size` | Attention output |
|
|
||||||
| Internal workspace | - | O(seq_len) | Variable | FlashAttention internal |
|
|
||||||
|
|
||||||
#### 2.2.6 Output Projection
|
|
||||||
|
|
||||||
| Variable | Line | Shape | Size | Notes |
|
|
||||||
|----------|------|-------|------|-------|
|
|
||||||
| `attn_output.view()` | 546 | `[seq_len, q_size]` | 0 (view) | Reshape for o_proj |
|
|
||||||
| `o_proj(attn_output)` | 547 | `[seq_len, hidden_size]` | `seq_len × hidden_size × dtype_size` | O projection output |
|
|
||||||
|
|
||||||
#### 2.2.7 Post-Attention LayerNorm
|
|
||||||
|
|
||||||
Same as input layernorm (2.2.1).
|
|
||||||
|
|
||||||
#### 2.2.8 MLP
|
|
||||||
|
|
||||||
Location: `qwen3.py:Qwen3MLP.forward()`
|
|
||||||
|
|
||||||
| Variable | Line | Shape | Size | Notes |
|
|
||||||
|----------|------|-------|------|-------|
|
|
||||||
| `gate_up` | 117 | `[seq_len, 2 × intermediate_size]` | `seq_len × 2 × intermediate_size × dtype_size` | **LARGEST TEMPORARY!** |
|
|
||||||
| `x, y = chunk()` | activation.py:13 | `[seq_len, intermediate_size]` × 2 | 0 (views) | Chunk views |
|
|
||||||
| `F.silu(x)` | activation.py:14 | `[seq_len, intermediate_size]` | `seq_len × intermediate_size × dtype_size` | SiLU activation |
|
|
||||||
| `silu(x) * y` | activation.py:14 | `[seq_len, intermediate_size]` | `seq_len × intermediate_size × dtype_size` | Gated output |
|
|
||||||
| `down_proj()` | 119 | `[seq_len, hidden_size]` | `seq_len × hidden_size × dtype_size` | MLP output |
|
|
||||||
|
|
||||||
### 2.3 Prefill Memory Summary
|
|
||||||
|
|
||||||
**Peak per-layer temporary memory**:
|
|
||||||
```
|
|
||||||
= qkv + RoPE_temps + attn_output + o_proj + layernorm + MLP_gate_up + MLP_activation
|
|
||||||
≈ seq_len × (qkv_size + (num_heads + num_kv_heads) × head_dim × 4 × 3
|
|
||||||
+ num_heads × head_dim + hidden_size × 2 + 2 × intermediate_size + intermediate_size) × dtype_size
|
|
||||||
```
|
|
||||||
|
|
||||||
**Dominant term**: `seq_len × 2 × intermediate_size × dtype_size` (MLP gate_up)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 3. Non-Pre-allocated Memory: Decode Phase
|
|
||||||
|
|
||||||
Location: `model_runner.py:run_layerwise_offload_decode()`
|
|
||||||
|
|
||||||
### 3.1 Persistent Tensors
|
|
||||||
|
|
||||||
| Variable | Line | Shape | Size | Notes |
|
|
||||||
|----------|------|-------|------|-------|
|
|
||||||
| `input_ids` | 604 | `[1]` | 8 bytes | Single token |
|
|
||||||
| `positions` | 605 | `[1]` | 8 bytes | Single position |
|
|
||||||
| `cu_seqlens_q` | 631 | `[2]` | 8 bytes | Fixed |
|
|
||||||
| `valid_tokens_per_block` | 613-622 | Python list | negligible | |
|
|
||||||
|
|
||||||
### 3.2 Per-Layer Temporary Tensors
|
|
||||||
|
|
||||||
#### 3.2.1 Views (Zero Additional Memory)
|
|
||||||
|
|
||||||
| Variable | Line | Shape | Notes |
|
|
||||||
|----------|------|-------|-------|
|
|
||||||
| `k_prefill` | 682 | `[prefill_len, num_kv_heads, head_dim]` | View of ring buffer |
|
|
||||||
| `v_prefill` | 682 | `[prefill_len, num_kv_heads, head_dim]` | View of ring buffer |
|
|
||||||
| `k_decode_prev` | 686-687 | `[num_decode_tokens-1, num_kv_heads, head_dim]` | View of decode buffer |
|
|
||||||
| `v_decode_prev` | 686-688 | `[num_decode_tokens-1, num_kv_heads, head_dim]` | View of decode buffer |
|
|
||||||
|
|
||||||
#### 3.2.2 New Allocations
|
|
||||||
|
|
||||||
| Variable | Line | Shape | Size | Notes |
|
|
||||||
|----------|------|-------|------|-------|
|
|
||||||
| `hidden_ln` | 654-657 | `[1, hidden_size]` | `hidden_size × dtype_size` | Tiny |
|
|
||||||
| `qkv` | 660 | `[1, qkv_size]` | `qkv_size × dtype_size` | Tiny |
|
|
||||||
| `q` | 667 | `[1, num_heads, head_dim]` | 0 (view) | |
|
|
||||||
| `k_new` | 668 | `[1, num_kv_heads, head_dim]` | 0 (view) | |
|
|
||||||
| `v_new` | 669 | `[1, num_kv_heads, head_dim]` | 0 (view) | |
|
|
||||||
| **`k_full`** | 689/692 | `[prefill_len + num_decode_tokens, num_kv_heads, head_dim]` | `(prefill_len + num_decode_tokens) × kv_dim × dtype_size` | **torch.cat - NEW ALLOCATION** |
|
|
||||||
| **`v_full`** | 690/693 | `[prefill_len + num_decode_tokens, num_kv_heads, head_dim]` | `(prefill_len + num_decode_tokens) × kv_dim × dtype_size` | **torch.cat - NEW ALLOCATION** |
|
|
||||||
| `cu_seqlens_k` | 710 | `[2]` | 8 bytes | Created per layer |
|
|
||||||
| `attn_output` | 712 | `[1, num_heads, head_dim]` | `num_heads × head_dim × dtype_size` | Tiny |
|
|
||||||
| MLP temps | 728 | `[1, ...]` | negligible | Single token |
|
|
||||||
|
|
||||||
### 3.3 Decode Memory Summary
|
|
||||||
|
|
||||||
**Peak per-layer temporary memory**:
|
|
||||||
```
|
|
||||||
= k_full + v_full + small_tensors
|
|
||||||
≈ 2 × (prefill_len + num_decode_tokens) × num_kv_heads × head_dim × dtype_size
|
|
||||||
≈ 2 × seq_len × kv_dim × dtype_size
|
|
||||||
```
|
|
||||||
|
|
||||||
**Dominant term**: `k_full` and `v_full` from `torch.cat()`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 4. Memory Comparison Table
|
|
||||||
|
|
||||||
For Qwen3-4B with 128k context:
|
|
||||||
|
|
||||||
| Category | Memory | Notes |
|
|
||||||
|----------|--------|-------|
|
|
||||||
| **Pre-allocated GPU** | ~2.2 GB | Ring buffer + decode buffer |
|
|
||||||
| **Pre-allocated CPU** | ~18.4 GB | Pinned memory |
|
|
||||||
| **Model Weights** | ~8 GB | |
|
|
||||||
| **Prefill Peak Temp** | ~10-12 GB | MLP gate_up dominant |
|
|
||||||
| **Decode Peak Temp** | ~512 MB | k_full + v_full |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 5. Optimization Opportunities
|
|
||||||
|
|
||||||
### 5.1 Decode: Pre-allocate k_full/v_full
|
|
||||||
|
|
||||||
**Current** (L689-693):
|
|
||||||
```python
|
|
||||||
k_full = torch.cat([k_prefill, k_decode_prev, k_new], dim=0) # New allocation each layer
|
|
||||||
v_full = torch.cat([v_prefill, v_decode_prev, v_new], dim=0) # New allocation each layer
|
|
||||||
```
|
|
||||||
|
|
||||||
**Optimized**:
|
|
||||||
```python
|
|
||||||
# Pre-allocate in OffloadEngine.__init__():
|
|
||||||
self.k_full_buffer = torch.zeros(max_seq_len + block_size, num_kv_heads, head_dim, ...)
|
|
||||||
self.v_full_buffer = torch.zeros(max_seq_len + block_size, num_kv_heads, head_dim, ...)
|
|
||||||
|
|
||||||
# In decode loop:
|
|
||||||
total_len = prefill_len + num_decode_tokens
|
|
||||||
k_full = self.k_full_buffer[:total_len]
|
|
||||||
k_full[:prefill_len].copy_(k_prefill)
|
|
||||||
k_full[prefill_len:prefill_len+num_decode_prev].copy_(k_decode_prev)
|
|
||||||
k_full[-1:].copy_(k_new)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Savings**: ~512 MB per decode step (for 128k)
|
|
||||||
|
|
||||||
### 5.2 Decode: Reuse cu_seqlens_k
|
|
||||||
|
|
||||||
**Current** (L710):
|
|
||||||
```python
|
|
||||||
cu_seqlens_k = torch.tensor([0, total_kv_tokens], dtype=torch.int32, device="cuda")
|
|
||||||
```
|
|
||||||
|
|
||||||
**Optimized**:
|
|
||||||
```python
|
|
||||||
# Pre-allocate once:
|
|
||||||
self.cu_seqlens_k = torch.zeros(2, dtype=torch.int32, device="cuda")
|
|
||||||
|
|
||||||
# In decode loop:
|
|
||||||
self.cu_seqlens_k[1] = total_kv_tokens
|
|
||||||
```
|
|
||||||
|
|
||||||
**Savings**: Negligible memory, but reduces allocation overhead.
|
|
||||||
|
|
||||||
### 5.3 RoPE: In-place or Pre-allocated Buffers
|
|
||||||
|
|
||||||
The RoPE implementation creates multiple float32 intermediate tensors. Options:
|
|
||||||
1. Pre-allocate buffers for Q and K rotary outputs
|
|
||||||
2. Use in-place operations where possible
|
|
||||||
3. Use fused RoPE kernel (e.g., from FlashAttention)
|
|
||||||
|
|
||||||
**Potential savings**: ~1.5 GB during prefill per layer
|
|
||||||
|
|
||||||
### 5.4 MLP: Cannot Optimize Easily
|
|
||||||
|
|
||||||
The MLP `gate_up` tensor is inherently required for the gated activation:
|
|
||||||
```python
|
|
||||||
gate_up = gate_up_proj(x) # [seq_len, 2 × intermediate_size]
|
|
||||||
x, y = gate_up.chunk(2, -1)
|
|
||||||
output = silu(x) * y
|
|
||||||
```
|
|
||||||
|
|
||||||
This is a fundamental computation pattern. Potential optimizations:
|
|
||||||
- Chunked MLP computation (process seq_len in chunks)
|
|
||||||
- Fused kernels that avoid materializing full gate_up
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 6. Memory Flow Diagram
|
|
||||||
|
|
||||||
### Prefill (per layer):
|
|
||||||
|
|
||||||
```
|
|
||||||
hidden_states ──┬──► LayerNorm ──► hidden_ln
|
|
||||||
│
|
|
||||||
residual ◄──────┘
|
|
||||||
|
|
||||||
hidden_ln ──► QKV_proj ──► qkv ──┬──► q ──► Q_norm ──► RoPE ──► q_rotated
|
|
||||||
├──► k ──► K_norm ──► RoPE ──► k_rotated
|
|
||||||
└──► v
|
|
||||||
|
|
||||||
q_rotated, k_rotated, v ──► FlashAttention ──► attn_output
|
|
||||||
|
|
||||||
attn_output ──► O_proj ──► hidden_states'
|
|
||||||
|
|
||||||
hidden_states', residual ──► LayerNorm ──► hidden_ln', residual'
|
|
||||||
|
|
||||||
hidden_ln' ──► MLP_gate_up ──► gate_up ──► SiLU×gate ──► MLP_down ──► hidden_states''
|
|
||||||
|
|
||||||
k_rotated, v ──► CPU_offload (sync copy)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Decode (per layer):
|
|
||||||
|
|
||||||
```
|
|
||||||
[CPU] k_cache_cpu, v_cache_cpu
|
|
||||||
│
|
|
||||||
▼ (H2D async to ring buffer)
|
|
||||||
[GPU] layer_k_cache[buffer_idx], layer_v_cache[buffer_idx]
|
|
||||||
│
|
|
||||||
▼ (view)
|
|
||||||
k_prefill, v_prefill
|
|
||||||
│
|
|
||||||
├──► torch.cat([k_prefill, k_decode_prev, k_new]) ──► k_full ⚠️ NEW ALLOC
|
|
||||||
│
|
|
||||||
└──► torch.cat([v_prefill, v_decode_prev, v_new]) ──► v_full ⚠️ NEW ALLOC
|
|
||||||
|
|
||||||
q_new, k_full, v_full ──► FlashAttention ──► attn_output
|
|
||||||
|
|
||||||
k_new, v_new ──► decode_k_buffer, decode_v_buffer (in-place store)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 7. Appendix: Size Calculations
|
|
||||||
|
|
||||||
### Qwen3-4B Example (128k context)
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Model config
|
|
||||||
seq_len = 131072
|
|
||||||
hidden_size = 2560
|
|
||||||
num_heads = 20
|
|
||||||
num_kv_heads = 8
|
|
||||||
head_dim = 128
|
|
||||||
intermediate_size = 13696
|
|
||||||
num_layers = 36
|
|
||||||
block_size = 1024
|
|
||||||
num_kv_buffers = 4
|
|
||||||
num_cpu_blocks = 128
|
|
||||||
dtype_size = 2 # fp16/bf16
|
|
||||||
|
|
||||||
# Derived
|
|
||||||
kv_dim = num_kv_heads * head_dim # 1024
|
|
||||||
q_size = num_heads * head_dim # 2560
|
|
||||||
qkv_size = q_size + 2 * kv_dim # 4608
|
|
||||||
|
|
||||||
# Pre-allocated GPU (OffloadEngine)
|
|
||||||
ring_buffer = 2 * num_kv_buffers * seq_len * kv_dim * dtype_size
|
|
||||||
# = 2 * 4 * 131072 * 1024 * 2 = 2,147,483,648 bytes = 2048 MB
|
|
||||||
|
|
||||||
decode_buffer = 2 * num_layers * block_size * kv_dim * dtype_size
|
|
||||||
# = 2 * 36 * 1024 * 1024 * 2 = 150,994,944 bytes = 144 MB
|
|
||||||
|
|
||||||
# Pre-allocated CPU
|
|
||||||
cpu_cache = 2 * num_layers * num_cpu_blocks * block_size * kv_dim * dtype_size
|
|
||||||
# = 2 * 36 * 128 * 1024 * 1024 * 2 = 19,327,352,832 bytes = 18432 MB
|
|
||||||
|
|
||||||
# Prefill temporaries (per layer peak)
|
|
||||||
mlp_gate_up = seq_len * 2 * intermediate_size * dtype_size
|
|
||||||
# = 131072 * 2 * 13696 * 2 = 7,180,648,448 bytes = 6848 MB
|
|
||||||
|
|
||||||
# Decode temporaries (per layer)
|
|
||||||
k_full = seq_len * kv_dim * dtype_size
|
|
||||||
# = 131072 * 1024 * 2 = 268,435,456 bytes = 256 MB
|
|
||||||
v_full = k_full # = 256 MB
|
|
||||||
# Total: 512 MB
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 8. Empirical Validation
|
|
||||||
|
|
||||||
This section validates the theoretical memory analysis against actual measurements.
|
|
||||||
|
|
||||||
### 8.1 Test Configuration
|
|
||||||
|
|
||||||
```bash
|
|
||||||
python tests/test_needle.py --enable-offload --input-len 100000 --block-size 1024
|
|
||||||
```
|
|
||||||
|
|
||||||
**Parameters:**
|
|
||||||
- Model: Qwen3-4B-Instruct
|
|
||||||
- `seq_len = 100000` (actual tokens: 99925)
|
|
||||||
- `block_size = 1024`
|
|
||||||
- `max_model_len = 131072`
|
|
||||||
- `num_kv_buffers = 4`
|
|
||||||
|
|
||||||
### 8.2 Theoretical Peak Memory Calculation
|
|
||||||
|
|
||||||
#### Step 1: Model Load Memory
|
|
||||||
|
|
||||||
| Component | Formula | Size |
|
|
||||||
|-----------|---------|------|
|
|
||||||
| Model weights | ~4B params × 2 bytes | ~8 GB |
|
|
||||||
| Ring buffer | 2 × 4 × 131072 × 1024 × 2 | 2048 MB |
|
|
||||||
| Decode buffer | 2 × 36 × 1024 × 1024 × 2 | 144 MB |
|
|
||||||
| **Subtotal** | | **~10.2 GB** |
|
|
||||||
|
|
||||||
#### Step 2: Prefill Activation Peak (per-layer)
|
|
||||||
|
|
||||||
| Component | Formula | Size |
|
|
||||||
|-----------|---------|------|
|
|
||||||
| hidden_states | 100000 × 2560 × 2 | 512 MB |
|
|
||||||
| residual | 100000 × 2560 × 2 | 512 MB |
|
|
||||||
| MLP gate_up | 100000 × 27392 × 2 | **5478 MB** |
|
|
||||||
| MLP silu×gate | 100000 × 13696 × 2 | 2739 MB |
|
|
||||||
| Other intermediates (qkv, RoPE, attn) | ~1-2 GB | ~1500 MB |
|
|
||||||
| **Subtotal** | | **~10 GB** |
|
|
||||||
|
|
||||||
#### Step 3: Total Peak
|
|
||||||
|
|
||||||
```
|
|
||||||
Total Peak = Model Load + Activation Peak
|
|
||||||
= 10.2 GB + 10 GB
|
|
||||||
= ~20.2 GB
|
|
||||||
```
|
|
||||||
|
|
||||||
### 8.3 Actual Measurement Results
|
|
||||||
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
# ... run inference ...
|
|
||||||
peak = torch.cuda.max_memory_allocated()
|
|
||||||
```
|
|
||||||
|
|
||||||
| Metric | Value |
|
|
||||||
|--------|-------|
|
|
||||||
| After model load | 9.82 GB |
|
|
||||||
| Peak during inference | **20.02 GB** |
|
|
||||||
| Activation peak (delta) | 10.20 GB |
|
|
||||||
|
|
||||||
### 8.4 Comparison: Theory vs Actual
|
|
||||||
|
|
||||||
| Component | Theoretical | Actual | Error |
|
|
||||||
|-----------|-------------|--------|-------|
|
|
||||||
| Model load memory | ~10.2 GB | 9.82 GB | -3.7% |
|
|
||||||
| Activation peak | ~10 GB | 10.20 GB | +2.0% |
|
|
||||||
| **Total peak** | **~20.2 GB** | **20.02 GB** | **< 1%** |
|
|
||||||
|
|
||||||
### 8.5 Key Findings
|
|
||||||
|
|
||||||
1. **Theoretical model is accurate**: < 5% error in all components.
|
|
||||||
|
|
||||||
2. **MLP gate_up is the dominant temporary**:
|
|
||||||
- Size: 5.35 GB (for 100k tokens)
|
|
||||||
- Accounts for ~50% of activation peak
|
|
||||||
- Formula: `seq_len × 2 × intermediate_size × dtype_size`
|
|
||||||
|
|
||||||
3. **Memory scaling with sequence length**:
|
|
||||||
| seq_len | Model Load | Activation Peak | Total Peak |
|
|
||||||
|---------|------------|-----------------|------------|
|
|
||||||
| 8k | ~10 GB | ~0.8 GB | ~11 GB |
|
|
||||||
| 32k | ~10 GB | ~3.2 GB | ~13 GB |
|
|
||||||
| 64k | ~10 GB | ~6.4 GB | ~16 GB |
|
|
||||||
| 100k | ~10 GB | ~10 GB | ~20 GB |
|
|
||||||
| 128k | ~10 GB | ~13 GB | ~23 GB |
|
|
||||||
|
|
||||||
4. **Decode memory is much smaller**:
|
|
||||||
- Per-step: ~512 MB for k_full + v_full (at 100k context)
|
|
||||||
- Does not grow with decode steps (constant per layer)
|
|
||||||
|
|
||||||
### 8.6 Memory Profiling Script
|
|
||||||
|
|
||||||
To reproduce the measurement:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import os
|
|
||||||
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from nanovllm import LLM, SamplingParams
|
|
||||||
from tests.utils import generate_needle_prompt
|
|
||||||
|
|
||||||
# Reset memory stats
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# Initialize LLM
|
|
||||||
llm = LLM(
|
|
||||||
"path/to/model",
|
|
||||||
enforce_eager=True,
|
|
||||||
max_model_len=131072,
|
|
||||||
max_num_batched_tokens=131072,
|
|
||||||
enable_cpu_offload=True,
|
|
||||||
kvcache_block_size=1024,
|
|
||||||
num_gpu_blocks=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
after_load = torch.cuda.memory_allocated()
|
|
||||||
print(f"After model load: {after_load / 1024**3:.2f} GB")
|
|
||||||
|
|
||||||
# Generate prompt and run inference
|
|
||||||
prompt, expected = generate_needle_prompt(
|
|
||||||
tokenizer=llm.tokenizer,
|
|
||||||
target_length=100000,
|
|
||||||
needle_position=0.5,
|
|
||||||
)
|
|
||||||
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
outputs = llm.generate([prompt], SamplingParams(max_tokens=32))
|
|
||||||
|
|
||||||
peak = torch.cuda.max_memory_allocated()
|
|
||||||
print(f"Peak during inference: {peak / 1024**3:.2f} GB")
|
|
||||||
```
|
|
||||||
@@ -1,233 +0,0 @@
|
|||||||
# Multi-Model Support
|
|
||||||
|
|
||||||
本文档描述 nanovllm 的多模型支持架构,以及如何添加新模型。
|
|
||||||
|
|
||||||
## 概述
|
|
||||||
|
|
||||||
nanovllm 通过模型注册表 (Model Registry) 机制支持多种模型架构。系统根据 HuggingFace config 中的 `architectures` 字段自动选择对应的模型实现。
|
|
||||||
|
|
||||||
### 当前支持的模型
|
|
||||||
|
|
||||||
| 架构 | 模型示例 | 文件 |
|
|
||||||
|------|---------|------|
|
|
||||||
| `Qwen3ForCausalLM` | Qwen3-0.6B, Qwen3-4B | `nanovllm/models/qwen3.py` |
|
|
||||||
| `Qwen2ForCausalLM` | Qwen2.5-7B | `nanovllm/models/qwen3.py` |
|
|
||||||
| `LlamaForCausalLM` | Llama-3.1-8B-Instruct | `nanovllm/models/llama.py` |
|
|
||||||
|
|
||||||
## 架构设计
|
|
||||||
|
|
||||||
### 模型注册表
|
|
||||||
|
|
||||||
```
|
|
||||||
nanovllm/models/
|
|
||||||
├── __init__.py # 导出 get_model_class, 导入所有模型
|
|
||||||
├── registry.py # 注册表核心: MODEL_REGISTRY, @register_model
|
|
||||||
├── qwen3.py # Qwen3/Qwen2 实现
|
|
||||||
└── llama.py # Llama 实现
|
|
||||||
```
|
|
||||||
|
|
||||||
### 动态模型加载流程
|
|
||||||
|
|
||||||
```
|
|
||||||
LLM(model_path)
|
|
||||||
→ Config.__post_init__()
|
|
||||||
→ hf_config = AutoConfig.from_pretrained(model_path)
|
|
||||||
→ ModelRunner.__init__()
|
|
||||||
→ model_class = get_model_class(hf_config) # 根据 architectures 选择
|
|
||||||
→ model = model_class(hf_config)
|
|
||||||
→ load_model(model, model_path)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 添加新模型
|
|
||||||
|
|
||||||
### 步骤 1: 创建模型文件
|
|
||||||
|
|
||||||
在 `nanovllm/models/` 下创建新文件,例如 `mistral.py`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
from nanovllm.layers.activation import SiluAndMul
|
|
||||||
from nanovllm.layers.attention import Attention
|
|
||||||
from nanovllm.layers.layernorm import RMSNorm
|
|
||||||
from nanovllm.layers.linear import QKVParallelLinear, MergedColumnParallelLinear, RowParallelLinear
|
|
||||||
from nanovllm.layers.rotary_embedding import get_rope
|
|
||||||
from nanovllm.layers.embed_head import VocabParallelEmbedding, ParallelLMHead
|
|
||||||
from nanovllm.models.registry import register_model
|
|
||||||
|
|
||||||
|
|
||||||
class MistralAttention(nn.Module):
|
|
||||||
def __init__(self, ...):
|
|
||||||
# 实现注意力层
|
|
||||||
pass
|
|
||||||
|
|
||||||
class MistralMLP(nn.Module):
|
|
||||||
def __init__(self, ...):
|
|
||||||
# 实现 MLP 层
|
|
||||||
pass
|
|
||||||
|
|
||||||
class MistralDecoderLayer(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
# 组合 Attention + MLP
|
|
||||||
pass
|
|
||||||
|
|
||||||
class MistralModel(nn.Module):
|
|
||||||
def __init__(self, config):
|
|
||||||
# Embedding + Layers + Norm
|
|
||||||
pass
|
|
||||||
|
|
||||||
@register_model("MistralForCausalLM")
|
|
||||||
class MistralForCausalLM(nn.Module):
|
|
||||||
# 权重映射 (HF 权重名 -> nanovllm 权重名)
|
|
||||||
packed_modules_mapping = {
|
|
||||||
"q_proj": ("qkv_proj", "q"),
|
|
||||||
"k_proj": ("qkv_proj", "k"),
|
|
||||||
"v_proj": ("qkv_proj", "v"),
|
|
||||||
"gate_proj": ("gate_up_proj", 0),
|
|
||||||
"up_proj": ("gate_up_proj", 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.model = MistralModel(config)
|
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
|
||||||
|
|
||||||
def forward(self, input_ids, positions):
|
|
||||||
return self.model(input_ids, positions)
|
|
||||||
|
|
||||||
def compute_logits(self, hidden_states):
|
|
||||||
return self.lm_head(hidden_states)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 步骤 2: 注册模型
|
|
||||||
|
|
||||||
在 `nanovllm/models/__init__.py` 中导入新模型:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from nanovllm.models import mistral # 添加这行
|
|
||||||
```
|
|
||||||
|
|
||||||
### 步骤 3: 处理特殊配置
|
|
||||||
|
|
||||||
如果模型有特殊的 RoPE scaling 或其他配置,需要在相应的 layer 中添加支持。
|
|
||||||
|
|
||||||
## 模型架构差异
|
|
||||||
|
|
||||||
### Qwen3 vs Llama
|
|
||||||
|
|
||||||
| 特性 | Qwen3 | Llama |
|
|
||||||
|------|-------|-------|
|
|
||||||
| QKV Bias | 可配置 (`attention_bias`) | 无 |
|
|
||||||
| Q/K Norm | 有 (RMSNorm, 当 bias=False) | 无 |
|
|
||||||
| MLP Bias | 无 | 无 |
|
|
||||||
| RoPE Scaling | 无 | llama3 类型 |
|
|
||||||
| RoPE Theta | 1,000,000 | 500,000 |
|
|
||||||
|
|
||||||
### RoPE Scaling 支持
|
|
||||||
|
|
||||||
目前支持的 RoPE 类型:
|
|
||||||
|
|
||||||
| `rope_type` | 说明 | 模型 |
|
|
||||||
|-------------|------|------|
|
|
||||||
| `None` | 标准 RoPE | Qwen3 |
|
|
||||||
| `llama3` | Llama 3 频率缩放 | Llama 3.1 |
|
|
||||||
|
|
||||||
Llama3 RoPE 特点:
|
|
||||||
- 低频分量 (长距离依赖): 缩放 1/factor
|
|
||||||
- 高频分量 (短距离依赖): 保持不变
|
|
||||||
- 中频分量: 平滑插值
|
|
||||||
|
|
||||||
## 权重加载
|
|
||||||
|
|
||||||
### packed_modules_mapping
|
|
||||||
|
|
||||||
nanovllm 将多个 HuggingFace 权重合并到单个张量中以提高效率:
|
|
||||||
|
|
||||||
```python
|
|
||||||
packed_modules_mapping = {
|
|
||||||
# HF 权重名: (nanovllm 权重名, shard_id)
|
|
||||||
"q_proj": ("qkv_proj", "q"), # Q 投影 -> QKV 合并
|
|
||||||
"k_proj": ("qkv_proj", "k"), # K 投影 -> QKV 合并
|
|
||||||
"v_proj": ("qkv_proj", "v"), # V 投影 -> QKV 合并
|
|
||||||
"gate_proj": ("gate_up_proj", 0), # Gate -> Gate+Up 合并
|
|
||||||
"up_proj": ("gate_up_proj", 1), # Up -> Gate+Up 合并
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
### 权重加载流程
|
|
||||||
|
|
||||||
```python
|
|
||||||
# nanovllm/utils/loader.py
|
|
||||||
def load_model(model, path):
|
|
||||||
for file in glob(path + "/*.safetensors"):
|
|
||||||
with safe_open(file) as f:
|
|
||||||
for weight_name in f.keys():
|
|
||||||
# 检查是否需要映射
|
|
||||||
if weight_name in packed_modules_mapping:
|
|
||||||
# 使用自定义 weight_loader
|
|
||||||
param.weight_loader(param, tensor, shard_id)
|
|
||||||
else:
|
|
||||||
# 直接复制
|
|
||||||
param.data.copy_(tensor)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 测试验证
|
|
||||||
|
|
||||||
### Needle-in-Haystack 测试
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Llama 3.1 (32K, offload 模式)
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python tests/test_needle.py \
|
|
||||||
--model ~/models/Llama-3.1-8B-Instruct \
|
|
||||||
--max-model-len 40960 \
|
|
||||||
--input-len 32768 \
|
|
||||||
--block-size 1024 \
|
|
||||||
--num-gpu-blocks 4 \
|
|
||||||
--enable-offload
|
|
||||||
|
|
||||||
# Qwen3 (8K, offload 模式)
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python tests/test_needle.py \
|
|
||||||
--model ~/models/Qwen3-4B-Instruct-2507 \
|
|
||||||
--max-model-len 40960 \
|
|
||||||
--input-len 8192 \
|
|
||||||
--enable-offload
|
|
||||||
```
|
|
||||||
|
|
||||||
### 测试结果
|
|
||||||
|
|
||||||
| 模型 | 输入长度 | Needle 位置 | 结果 |
|
|
||||||
|------|---------|-------------|------|
|
|
||||||
| Llama-3.1-8B | 32K | 50% | ✅ PASSED |
|
|
||||||
| Llama-3.1-8B | 32K | 90% | ✅ PASSED |
|
|
||||||
| Llama-3.1-8B | 32K | 10% | ❌ FAILED (Lost in Middle) |
|
|
||||||
| Qwen3-4B | 8K | 50% | ✅ PASSED |
|
|
||||||
|
|
||||||
## 文件结构
|
|
||||||
|
|
||||||
```
|
|
||||||
nanovllm/
|
|
||||||
├── models/
|
|
||||||
│ ├── __init__.py # 模型导出和导入
|
|
||||||
│ ├── registry.py # 注册表实现
|
|
||||||
│ ├── qwen3.py # Qwen3/Qwen2 模型
|
|
||||||
│ └── llama.py # Llama 模型
|
|
||||||
├── layers/
|
|
||||||
│ ├── rotary_embedding.py # RoPE (含 Llama3 scaling)
|
|
||||||
│ ├── attention.py # FlashAttention wrapper
|
|
||||||
│ ├── linear.py # 并行 Linear 层
|
|
||||||
│ └── ...
|
|
||||||
└── engine/
|
|
||||||
└── model_runner.py # 动态模型加载
|
|
||||||
```
|
|
||||||
|
|
||||||
## 注意事项
|
|
||||||
|
|
||||||
1. **Tokenizer 差异**: 不同模型的 tokenizer 分词策略不同,例如 Llama 将 "7492" 分为 2 tokens,Qwen3 分为 4 tokens。
|
|
||||||
|
|
||||||
2. **RoPE Scaling**: 如果模型使用非标准 RoPE,需要在 `rotary_embedding.py` 中添加支持。
|
|
||||||
|
|
||||||
3. **CPU Offload**: 在 3090 等显存有限的 GPU 上,使用 `--enable-offload` 进行长上下文测试。
|
|
||||||
|
|
||||||
4. **Lost in Middle**: LLM 对开头信息的记忆能力较弱,这是模型本身的限制,不是实现问题。
|
|
||||||
@@ -1,306 +0,0 @@
|
|||||||
# CPU Offload Accuracy Issue Investigation
|
|
||||||
|
|
||||||
## Problem Summary
|
|
||||||
|
|
||||||
**UPDATE (2026-01-12)**: Single request inference works correctly! The issue is with batch/sequential request handling.
|
|
||||||
|
|
||||||
| Mode | Testing Method | Accuracy |
|
|
||||||
|------|----------------|----------|
|
|
||||||
| **CPU Offload** | **Independent** (1 request per process) | **100%** ✓ |
|
|
||||||
| **CPU Offload** | Batch (multiple requests per process) | 66% ✗ |
|
|
||||||
| **Non-Offload** | Batch | 100% ✓ |
|
|
||||||
|
|
||||||
**Conclusion**: The offload implementation is correct for single requests. The bug is in state cleanup between sequential requests within the same process.
|
|
||||||
|
|
||||||
## Test Environment
|
|
||||||
|
|
||||||
- **Model**: Llama-3.1-8B-Instruct
|
|
||||||
- **Task**: RULER NIAH (Needle-In-A-Haystack) 32K context
|
|
||||||
- **GPU**: NVIDIA A100-SXM4-80GB
|
|
||||||
- **Data**: `tests/data/ruler_niah/niah_single_1_32k.jsonl` (100 samples)
|
|
||||||
|
|
||||||
## Reproduction Commands
|
|
||||||
|
|
||||||
### Non-Offload Mode (100% accuracy)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler_niah.py \
|
|
||||||
--model ~/models/Llama-3.1-8B-Instruct \
|
|
||||||
--gpu-utilization 0.7 \
|
|
||||||
--quiet
|
|
||||||
```
|
|
||||||
|
|
||||||
**Configuration**:
|
|
||||||
- KV Cache: GPU only, 51 blocks (6528 MB)
|
|
||||||
- Block size: 1024 tokens
|
|
||||||
|
|
||||||
### Offload Mode (66% accuracy)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler_niah.py \
|
|
||||||
--model ~/models/Llama-3.1-8B-Instruct \
|
|
||||||
--enable-offload \
|
|
||||||
--quiet
|
|
||||||
```
|
|
||||||
|
|
||||||
**Configuration**:
|
|
||||||
- KV Cache: GPU 4 blocks (512 MB) + CPU 32 blocks (4096 MB)
|
|
||||||
- Ring buffer: 4 buffers × 33280 tokens (520 MB)
|
|
||||||
- Per-layer decode buffer: 128 MB
|
|
||||||
- Block size: 1024 tokens
|
|
||||||
|
|
||||||
## Observed Failure Patterns
|
|
||||||
|
|
||||||
From the 5-sample verbose test:
|
|
||||||
|
|
||||||
| Sample | Expected | Offload Output | Status |
|
|
||||||
|--------|----------|----------------|--------|
|
|
||||||
| 0 | 8930103 | `: 8930103.` | PASS |
|
|
||||||
| 1 | 4194548 | `: 419 multiplication of 4548.` | **FAIL** |
|
|
||||||
| 2 | 8231838 | `:ное 8231838.` | PASS |
|
|
||||||
| 3 | 8835373 | `: 8835373.` | PASS |
|
|
||||||
| 4 | 7754864 | `aster 7754864.` | PASS |
|
|
||||||
|
|
||||||
**Failure pattern**: The model sometimes produces corrupted or split outputs (e.g., "419 multiplication of 4548" instead of "4194548").
|
|
||||||
|
|
||||||
## Architecture Overview
|
|
||||||
|
|
||||||
### Offload Mode Data Flow
|
|
||||||
|
|
||||||
```
|
|
||||||
Prefill Phase:
|
|
||||||
1. Input tokens → chunked into 2048-token chunks
|
|
||||||
2. Each chunk processed layer by layer:
|
|
||||||
- Load KV from CPU → GPU ring buffer
|
|
||||||
- Compute attention
|
|
||||||
- Store KV back to CPU
|
|
||||||
3. Ring buffer holds recent KV for decode
|
|
||||||
|
|
||||||
Decode Phase:
|
|
||||||
1. For each new token:
|
|
||||||
- Load all layer KV from CPU (one layer at a time)
|
|
||||||
- Compute attention against full context
|
|
||||||
- Generate next token
|
|
||||||
```
|
|
||||||
|
|
||||||
### Key Components
|
|
||||||
|
|
||||||
| File | Component | Description |
|
|
||||||
|------|-----------|-------------|
|
|
||||||
| `nanovllm/kvcache/offload_engine.py` | `OffloadEngine` | Manages CPU↔GPU KV cache transfers |
|
|
||||||
| `nanovllm/kvcache/offload_engine.py` | `RingKVBuffer` | GPU ring buffer for recent KV |
|
|
||||||
| `nanovllm/engine/model_runner.py` | `run_chunked_offload_prefill()` | Chunked prefill with offload |
|
|
||||||
| `nanovllm/engine/model_runner.py` | `run_offload_decode()` | Layer-wise decode with offload |
|
|
||||||
| `nanovllm/kvcache/hybrid_manager.py` | `HybridBlockManager` | CPU block allocation |
|
|
||||||
|
|
||||||
## Potential Root Causes
|
|
||||||
|
|
||||||
### 1. Ring Buffer Index/Position Issues
|
|
||||||
|
|
||||||
**Location**: `nanovllm/kvcache/offload_engine.py`
|
|
||||||
|
|
||||||
The ring buffer uses modular indexing. Potential issues:
|
|
||||||
- Position calculation errors during prefill/decode transition
|
|
||||||
- Off-by-one errors in KV storage/retrieval
|
|
||||||
- Incorrect handling when sequence length approaches `max_seq_len`
|
|
||||||
|
|
||||||
**Recent fix applied**: `max_seq_len = max_model_len + 512` to prevent overflow, but there may be other indexing issues.
|
|
||||||
|
|
||||||
### 2. Chunked Prefill KV Storage
|
|
||||||
|
|
||||||
**Location**: `nanovllm/engine/model_runner.py:run_chunked_offload_prefill()`
|
|
||||||
|
|
||||||
During chunked prefill:
|
|
||||||
- KV computed for chunk N must be correctly stored before processing chunk N+1
|
|
||||||
- Position IDs must be correctly accumulated across chunks
|
|
||||||
- CPU block allocation must be contiguous and correctly tracked
|
|
||||||
|
|
||||||
**Suspect areas**:
|
|
||||||
```python
|
|
||||||
# Check if positions are correctly tracked across chunks
|
|
||||||
# Check if KV is correctly copied to CPU after each chunk
|
|
||||||
# Check if ring buffer indices align with CPU block indices
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Decode Phase KV Loading
|
|
||||||
|
|
||||||
**Location**: `nanovllm/engine/model_runner.py:run_offload_decode()`
|
|
||||||
|
|
||||||
During decode:
|
|
||||||
- Must load KV for ALL previous tokens (both prefill and decode)
|
|
||||||
- Layer-by-layer loading must be synchronized correctly
|
|
||||||
- Attention computation must use correct sequence length
|
|
||||||
|
|
||||||
**Suspect areas**:
|
|
||||||
```python
|
|
||||||
# Check if decode loads KV for full context length
|
|
||||||
# Check if new decode KV is stored correctly
|
|
||||||
# Check if attention mask/positions are correct
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. CPU↔GPU Transfer Synchronization
|
|
||||||
|
|
||||||
**Location**: `nanovllm/kvcache/offload_engine.py`
|
|
||||||
|
|
||||||
CUDA streams and synchronization:
|
|
||||||
- Async copies may complete out of order
|
|
||||||
- Missing synchronization points could cause stale data
|
|
||||||
- Stream priorities may affect correctness
|
|
||||||
|
|
||||||
### 5. Numerical Precision
|
|
||||||
|
|
||||||
- CPU tensors use float16/bfloat16
|
|
||||||
- GPU computation precision
|
|
||||||
- Potential precision loss during transfers
|
|
||||||
|
|
||||||
## Debugging Strategy
|
|
||||||
|
|
||||||
### Step 1: Identify Failing Samples
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run verbose mode to see which samples fail
|
|
||||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler_niah.py \
|
|
||||||
--model ~/models/Llama-3.1-8B-Instruct \
|
|
||||||
--enable-offload \
|
|
||||||
--verbose 2>&1 | tee offload_verbose.log
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 2: Compare Token-by-Token
|
|
||||||
|
|
||||||
Create a debug script to compare token generation between offload and non-offload modes for a failing sample:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Compare logits at each decode step
|
|
||||||
# Check if divergence starts at a specific position
|
|
||||||
# Log KV cache contents at divergence point
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 3: Verify KV Cache Contents
|
|
||||||
|
|
||||||
Add debugging to `OffloadEngine`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# In store_kv(): Log what's being stored
|
|
||||||
# In load_kv(): Log what's being loaded
|
|
||||||
# Compare loaded KV with expected values
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 4: Check Position/Index Calculations
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Log ring buffer write/read positions
|
|
||||||
# Log CPU block indices
|
|
||||||
# Verify position IDs match actual token positions
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 5: Isolate the Bug
|
|
||||||
|
|
||||||
1. Test with shorter sequences (16K, 8K) to see if issue is length-dependent
|
|
||||||
2. Test with single chunk (no chunking) to isolate chunked prefill
|
|
||||||
3. Test prefill-only (no decode) to isolate decode phase
|
|
||||||
|
|
||||||
## Quick Debugging Commands
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Test single failing sample with verbose output
|
|
||||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler_niah.py \
|
|
||||||
--model ~/models/Llama-3.1-8B-Instruct \
|
|
||||||
--enable-offload \
|
|
||||||
--sample-indices 1 \
|
|
||||||
--verbose
|
|
||||||
|
|
||||||
# Test with different context lengths
|
|
||||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=.:$PYTHONPATH python tests/test_ruler_niah.py \
|
|
||||||
--model ~/models/Llama-3.1-8B-Instruct \
|
|
||||||
--enable-offload \
|
|
||||||
--max-model-len 16384 \
|
|
||||||
--verbose
|
|
||||||
```
|
|
||||||
|
|
||||||
## Related Documentation
|
|
||||||
|
|
||||||
- [`docs/ruler_niah_standalone_test.md`](ruler_niah_standalone_test.md) - Test setup and background
|
|
||||||
- [`docs/layerwise_offload_memory_analysis.md`](layerwise_offload_memory_analysis.md) - Memory analysis (if exists)
|
|
||||||
|
|
||||||
## Test Results Log
|
|
||||||
|
|
||||||
### 2026-01-12 (Updated - Independent Testing)
|
|
||||||
|
|
||||||
**Key Finding**: When each sample is tested independently (separate Python process per sample), CPU offload achieves **100% accuracy**.
|
|
||||||
|
|
||||||
| Test | Mode | Testing Method | Samples | Passed | Accuracy |
|
|
||||||
|------|------|----------------|---------|--------|----------|
|
|
||||||
| RULER NIAH 32K | CPU Offload | **Independent** (separate process) | 100 | 100 | **100%** |
|
|
||||||
| RULER NIAH 32K | CPU Offload | Batch (single process) | 100 | 66 | 66% |
|
|
||||||
| RULER NIAH 32K | Non-Offload | Batch (single process) | 100 | 100 | 100% |
|
|
||||||
|
|
||||||
**Test Configuration (Independent Mode)**:
|
|
||||||
- GPUs: 4x RTX 3090 (parallel testing)
|
|
||||||
- Each sample: Fresh Python process with new LLM instance
|
|
||||||
- Port: Each GPU uses unique port (2333+gpu_id)
|
|
||||||
- Duration: 17.9 minutes for 100 samples
|
|
||||||
- Throughput: 5.58 samples/min
|
|
||||||
|
|
||||||
### 2025-01-12 (Original - Batch Testing)
|
|
||||||
|
|
||||||
| Test | Mode | Samples | Passed | Accuracy |
|
|
||||||
|------|------|---------|--------|----------|
|
|
||||||
| RULER NIAH 32K | Non-Offload | 100 | 100 | 100% |
|
|
||||||
| RULER NIAH 32K | CPU Offload | 100 | 66 | 66% |
|
|
||||||
|
|
||||||
## Root Cause Analysis Update
|
|
||||||
|
|
||||||
### Confirmed: Single Request Inference is Correct
|
|
||||||
|
|
||||||
The 100% accuracy in independent testing mode confirms that:
|
|
||||||
1. **Single request inference works correctly** - The offload engine, ring buffer, and chunked prefill are functioning properly for individual requests
|
|
||||||
2. **The bug is in batch/sequential request handling** - State accumulation or incomplete cleanup between requests causes failures
|
|
||||||
|
|
||||||
### Suspected Issue: State Accumulation Between Requests
|
|
||||||
|
|
||||||
When multiple requests are processed in the same Python process:
|
|
||||||
- The first request succeeds (e.g., Sample 0: PASS)
|
|
||||||
- Subsequent requests may fail due to:
|
|
||||||
- Residual state in ring buffer
|
|
||||||
- Incomplete KV cache cleanup
|
|
||||||
- Position tracking errors across requests
|
|
||||||
- CPU block allocation fragmentation
|
|
||||||
|
|
||||||
### Evidence
|
|
||||||
|
|
||||||
From batch mode testing (5 samples):
|
|
||||||
| Sample | Expected | Output | Status |
|
|
||||||
|--------|----------|--------|--------|
|
|
||||||
| 0 | 8930103 | `: 8930103.` | PASS (first request) |
|
|
||||||
| 1 | 4194548 | `: 419 multiplication of 4548.` | **FAIL** (second request) |
|
|
||||||
| 2 | 8231838 | `:ное 8231838.` | PASS |
|
|
||||||
| 3 | 8835373 | `: 8835373.` | PASS |
|
|
||||||
| 4 | 7754864 | `aster 7754864.` | PASS |
|
|
||||||
|
|
||||||
The corrupted output in Sample 1 suggests interference from Sample 0's state.
|
|
||||||
|
|
||||||
## Workaround
|
|
||||||
|
|
||||||
Use independent testing mode (separate process per request) for production evaluation:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Using test_ruler_niah.sh for parallel independent testing
|
|
||||||
./tests/test_ruler_niah.sh --gpus "0,1,2,3" --total 100
|
|
||||||
|
|
||||||
# Or manually run each sample in a separate process
|
|
||||||
for i in $(seq 0 99); do
|
|
||||||
CUDA_VISIBLE_DEVICES=0 python tests/test_ruler_niah.py \
|
|
||||||
--enable-offload --sample-indices $i --quiet
|
|
||||||
done
|
|
||||||
```
|
|
||||||
|
|
||||||
## Next Steps
|
|
||||||
|
|
||||||
1. [x] ~~Identify pattern in failing samples~~ → Pattern: First sample usually passes, failures occur in subsequent samples
|
|
||||||
2. [ ] **Investigate state cleanup between requests in offload mode**
|
|
||||||
- Check `OffloadEngine` reset/cleanup logic
|
|
||||||
- Check ring buffer state between requests
|
|
||||||
- Check CPU block manager cleanup
|
|
||||||
3. [ ] Add `reset()` method to `OffloadEngine` for explicit state cleanup
|
|
||||||
4. [ ] Compare state between first and second request in batch mode
|
|
||||||
5. [ ] Write unit test that reproduces the batch mode failure
|
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
# RULER Benchmark 测试报告
|
|
||||||
|
|
||||||
**测试日期**: 2025-01-14
|
|
||||||
**测试环境**: 6x RTX 3090, CPU Offload 模式
|
|
||||||
**模型**: Llama-3.1-8B-Instruct
|
|
||||||
**上下文长度**: 32K tokens
|
|
||||||
|
|
||||||
## 测试概述
|
|
||||||
|
|
||||||
使用 RULER benchmark 对 nano-vllm 的 CPU offload 模式进行全面的长上下文能力测试。RULER 是 NVIDIA 开发的长上下文评测基准,包含 13 个任务类别。
|
|
||||||
|
|
||||||
## 测试结果
|
|
||||||
|
|
||||||
### 总体结果
|
|
||||||
|
|
||||||
| 类别 | 数据集 | 正确/总数 | 准确率 | 平均分数 |
|
|
||||||
|------|--------|-----------|--------|----------|
|
|
||||||
| **NIAH Single** | niah_single_1 | 100/100 | 100.0% | 1.000 |
|
|
||||||
| | niah_single_2 | 100/100 | 100.0% | 1.000 |
|
|
||||||
| | niah_single_3 | 100/100 | 100.0% | 1.000 |
|
|
||||||
| **NIAH MultiKey** | niah_multikey_1 | 100/100 | 100.0% | 1.000 |
|
|
||||||
| | niah_multikey_2 | 90/100 | 90.0% | 0.900 |
|
|
||||||
| | niah_multikey_3 | 93/100 | 93.0% | 0.930 |
|
|
||||||
| **NIAH Other** | niah_multiquery | 100/100 | 100.0% | 1.000 |
|
|
||||||
| | niah_multivalue | 100/100 | 100.0% | 1.000 |
|
|
||||||
| **QA** | qa_1 | 79/100 | 79.0% | 0.790 |
|
|
||||||
| | qa_2 | 51/100 | 51.0% | 0.510 |
|
|
||||||
| **Aggregation** | cwe | 86/100 | 86.0% | 0.680 |
|
|
||||||
| | fwe | 98/100 | 98.0% | 0.923 |
|
|
||||||
| **Variable Tracking** | vt | 100/100 | 100.0% | 0.934 |
|
|
||||||
| **总计** | **13 数据集** | **1197/1300** | **92.1%** | **0.897** |
|
|
||||||
|
|
||||||
### 分类性能分析
|
|
||||||
|
|
||||||
| 任务类别 | 描述 | 准确率 | 评价 |
|
|
||||||
|----------|------|--------|------|
|
|
||||||
| NIAH Single | 单 needle 检索 | 100% | 优秀 |
|
|
||||||
| NIAH MultiKey | 多 key 检索 | 94.3% | 良好 |
|
|
||||||
| NIAH MultiQuery/Value | 复杂检索 | 100% | 优秀 |
|
|
||||||
| QA | 问答理解 | 65% | 一般 |
|
|
||||||
| Aggregation (CWE/FWE) | 信息聚合 | 92% | 良好 |
|
|
||||||
| Variable Tracking | 变量追踪 | 100% | 优秀 |
|
|
||||||
|
|
||||||
## 发现的问题及修复
|
|
||||||
|
|
||||||
### 问题: FWE 测试崩溃
|
|
||||||
|
|
||||||
**症状**: 第 63 个样本处触发 `AssertionError: No sequences scheduled`
|
|
||||||
|
|
||||||
**根因分析**:
|
|
||||||
1. Sample 63 的输入有 32760 tokens(接近 max_model_len=32768)
|
|
||||||
2. Decode 到第 9 步时,需要第 33 个 KV block
|
|
||||||
3. 但系统只配置了 32 个 blocks(32768/1024=32)
|
|
||||||
4. 调度器尝试 preempt 但单序列模式下无法恢复
|
|
||||||
|
|
||||||
**解决方案**:
|
|
||||||
```python
|
|
||||||
# 修改前
|
|
||||||
DEFAULT_MAX_MODEL_LEN = 32768
|
|
||||||
|
|
||||||
# 修改后: 为 output tokens 预留空间
|
|
||||||
DEFAULT_MAX_MODEL_LEN = 32896 # 32768 + 128
|
|
||||||
```
|
|
||||||
|
|
||||||
**建议的代码改进**:
|
|
||||||
1. 在 scheduler 中添加死锁检测和清晰错误信息
|
|
||||||
2. 在配置验证时,如果 max_model_len 与 max_input 过于接近,发出警告
|
|
||||||
|
|
||||||
## 评估方法
|
|
||||||
|
|
||||||
遵循 RULER 官方评估标准:
|
|
||||||
- **NIAH/VT/CWE/FWE**: `string_match_all` - 召回率 (找到的参考数/总参考数)
|
|
||||||
- **QA**: `string_match_part` - 任意参考匹配即满分
|
|
||||||
|
|
||||||
参考: https://github.com/NVIDIA/RULER
|
|
||||||
|
|
||||||
## 测试配置
|
|
||||||
|
|
||||||
```python
|
|
||||||
LLM(
|
|
||||||
model_path="~/models/Llama-3.1-8B-Instruct",
|
|
||||||
max_model_len=32896,
|
|
||||||
max_num_batched_tokens=32896,
|
|
||||||
enable_cpu_offload=True,
|
|
||||||
num_gpu_blocks=4,
|
|
||||||
kvcache_block_size=1024,
|
|
||||||
enforce_eager=True,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 结论
|
|
||||||
|
|
||||||
1. **长上下文检索能力**: nano-vllm CPU offload 模式在 32K 上下文下表现优秀,NIAH 类任务准确率接近 100%
|
|
||||||
|
|
||||||
2. **复杂推理能力**: QA 任务准确率较低 (65%),这是模型本身能力的体现,与 offload 机制无关
|
|
||||||
|
|
||||||
3. **稳定性**: 修复 max_model_len 配置后,所有 1300 个样本测试均稳定完成
|
|
||||||
|
|
||||||
4. **性能**: 单样本测试时间约 25-35 秒,主要受 CPU-GPU 数据传输影响
|
|
||||||
@@ -1,297 +0,0 @@
|
|||||||
# RULER NIAH Standalone Test Plan
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
This document describes how to independently test nano-vllm's CPU offload functionality using RULER benchmark's NIAH (Needle-In-A-Haystack) task data.
|
|
||||||
|
|
||||||
## Background
|
|
||||||
|
|
||||||
### Problem Being Investigated
|
|
||||||
|
|
||||||
When running 32K sequence length tests with CPU offload mode, the model outputs garbled text instead of finding the magic number. This issue was traced to:
|
|
||||||
|
|
||||||
- **Root Cause**: Ring buffer `max_seq_len` was set equal to `max_model_len` (32768)
|
|
||||||
- **Issue**: When prefill uses ~32K tokens, decode needs to store KV at position 32768+, but ring buffer only has indices 0-32767
|
|
||||||
- **Fix Applied**: In `nanovllm/kvcache/__init__.py`, changed `max_seq_len = max_model_len + 512`
|
|
||||||
|
|
||||||
### Test Objective
|
|
||||||
|
|
||||||
Verify that the fix works correctly by running a standalone test with actual RULER NIAH data.
|
|
||||||
|
|
||||||
## Step 1: Copy Test Data
|
|
||||||
|
|
||||||
### Source Location
|
|
||||||
|
|
||||||
```
|
|
||||||
/home/zijie/Code/x-attention/eval/RULER/scripts/benchmark_root/full_fuse_16_llama3.1-8b-chat/synthetic/32768/data/niah_single_1/validation.jsonl
|
|
||||||
```
|
|
||||||
|
|
||||||
### Data Format
|
|
||||||
|
|
||||||
Each line is a JSON object:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"input": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nA special magic number is hidden within the following text...",
|
|
||||||
"outputs": ["8930103"],
|
|
||||||
"length": 32768
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
- `input`: Full prompt with Llama 3.1 chat template (~122K characters, ~30K tokens)
|
|
||||||
- `outputs`: Expected answer (the magic number to find)
|
|
||||||
- `length`: Target sequence length in tokens
|
|
||||||
|
|
||||||
### Copy Command
|
|
||||||
|
|
||||||
```bash
|
|
||||||
mkdir -p /home/zijie/Code/nano-vllm/tests/data/ruler_niah
|
|
||||||
cp /home/zijie/Code/x-attention/eval/RULER/scripts/benchmark_root/full_fuse_16_llama3.1-8b-chat/synthetic/32768/data/niah_single_1/validation.jsonl \
|
|
||||||
/home/zijie/Code/nano-vllm/tests/data/ruler_niah/niah_single_1_32k.jsonl
|
|
||||||
```
|
|
||||||
|
|
||||||
## Step 2: Create Test Script
|
|
||||||
|
|
||||||
Create `/home/zijie/Code/nano-vllm/tests/test_ruler_niah_32k.py`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
"""
|
|
||||||
Standalone test for RULER NIAH task with 32K context length.
|
|
||||||
|
|
||||||
This test verifies that CPU offload mode correctly handles long sequences
|
|
||||||
where prefill tokens approach max_model_len.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python tests/test_ruler_niah_32k.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import torch
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from nanovllm import LLM
|
|
||||||
from nanovllm.config import SamplingParams
|
|
||||||
|
|
||||||
# Configuration
|
|
||||||
MODEL_PATH = "/data/models/Llama-3.1-8B-Instruct"
|
|
||||||
DATA_FILE = Path(__file__).parent / "data/ruler_niah/niah_single_1_32k.jsonl"
|
|
||||||
MAX_MODEL_LEN = 32768
|
|
||||||
MAX_NEW_TOKENS = 50
|
|
||||||
|
|
||||||
# CPU Offload Settings
|
|
||||||
ENABLE_CPU_OFFLOAD = True
|
|
||||||
NUM_GPU_BLOCKS = 4
|
|
||||||
BLOCK_SIZE = 1024
|
|
||||||
|
|
||||||
|
|
||||||
def load_test_sample(filepath: Path, index: int = 0) -> dict:
|
|
||||||
"""Load a single test sample from JSONL file."""
|
|
||||||
with open(filepath) as f:
|
|
||||||
for i, line in enumerate(f):
|
|
||||||
if i == index:
|
|
||||||
return json.loads(line)
|
|
||||||
raise ValueError(f"Sample index {index} not found")
|
|
||||||
|
|
||||||
|
|
||||||
def test_niah_single():
|
|
||||||
"""Test NIAH single needle task with 32K context."""
|
|
||||||
print("=" * 60)
|
|
||||||
print("RULER NIAH 32K Standalone Test")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Load test data
|
|
||||||
sample = load_test_sample(DATA_FILE, index=0)
|
|
||||||
prompt = sample["input"]
|
|
||||||
expected = sample["outputs"][0]
|
|
||||||
|
|
||||||
print(f"Prompt length: {len(prompt)} characters")
|
|
||||||
print(f"Expected answer: {expected}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# Initialize model with CPU offload
|
|
||||||
print("Initializing LLM with CPU offload...")
|
|
||||||
llm = LLM(
|
|
||||||
model=MODEL_PATH,
|
|
||||||
max_model_len=MAX_MODEL_LEN,
|
|
||||||
enable_cpu_offload=ENABLE_CPU_OFFLOAD,
|
|
||||||
num_gpu_blocks=NUM_GPU_BLOCKS,
|
|
||||||
kvcache_block_size=BLOCK_SIZE,
|
|
||||||
enforce_eager=True, # Disable CUDA graphs for debugging
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate
|
|
||||||
print("Generating response...")
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
temperature=0.0, # Greedy
|
|
||||||
max_tokens=MAX_NEW_TOKENS,
|
|
||||||
)
|
|
||||||
|
|
||||||
outputs = llm.generate([prompt], sampling_params)
|
|
||||||
generated_text = outputs[0].outputs[0].text
|
|
||||||
|
|
||||||
print()
|
|
||||||
print("=" * 60)
|
|
||||||
print("Results")
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"Expected: {expected}")
|
|
||||||
print(f"Generated: {generated_text[:200]}...")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# Check if expected number is in output
|
|
||||||
if expected in generated_text:
|
|
||||||
print("SUCCESS: Magic number found in output!")
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
print("FAILED: Magic number NOT found in output")
|
|
||||||
print(f"Full output: {generated_text}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def test_multiple_samples(num_samples: int = 5):
|
|
||||||
"""Test multiple NIAH samples."""
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"Testing {num_samples} NIAH samples with 32K context")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Initialize model once
|
|
||||||
llm = LLM(
|
|
||||||
model=MODEL_PATH,
|
|
||||||
max_model_len=MAX_MODEL_LEN,
|
|
||||||
enable_cpu_offload=ENABLE_CPU_OFFLOAD,
|
|
||||||
num_gpu_blocks=NUM_GPU_BLOCKS,
|
|
||||||
kvcache_block_size=BLOCK_SIZE,
|
|
||||||
enforce_eager=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
temperature=0.0,
|
|
||||||
max_tokens=MAX_NEW_TOKENS,
|
|
||||||
)
|
|
||||||
|
|
||||||
correct = 0
|
|
||||||
for i in range(num_samples):
|
|
||||||
sample = load_test_sample(DATA_FILE, index=i)
|
|
||||||
prompt = sample["input"]
|
|
||||||
expected = sample["outputs"][0]
|
|
||||||
|
|
||||||
outputs = llm.generate([prompt], sampling_params)
|
|
||||||
generated_text = outputs[0].outputs[0].text
|
|
||||||
|
|
||||||
if expected in generated_text:
|
|
||||||
print(f"Sample {i}: PASS (found {expected})")
|
|
||||||
correct += 1
|
|
||||||
else:
|
|
||||||
print(f"Sample {i}: FAIL (expected {expected}, got: {generated_text[:50]}...)")
|
|
||||||
|
|
||||||
print()
|
|
||||||
print(f"Accuracy: {correct}/{num_samples} ({100*correct/num_samples:.1f}%)")
|
|
||||||
return correct == num_samples
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import sys
|
|
||||||
|
|
||||||
if len(sys.argv) > 1 and sys.argv[1] == "--all":
|
|
||||||
success = test_multiple_samples(5)
|
|
||||||
else:
|
|
||||||
success = test_niah_single()
|
|
||||||
|
|
||||||
sys.exit(0 if success else 1)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Step 3: Run Test
|
|
||||||
|
|
||||||
### Single Sample Test
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd /home/zijie/Code/nano-vllm
|
|
||||||
CUDA_VISIBLE_DEVICES=2,3,4,5 python tests/test_ruler_niah_32k.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### All 5 Samples
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd /home/zijie/Code/nano-vllm
|
|
||||||
CUDA_VISIBLE_DEVICES=2,3,4,5 python tests/test_ruler_niah_32k.py --all
|
|
||||||
```
|
|
||||||
|
|
||||||
## Step 4: Expected Results
|
|
||||||
|
|
||||||
### Before Fix (Bug)
|
|
||||||
|
|
||||||
- Output: Garbled text like "not only has been replaced by thesiums..."
|
|
||||||
- Score: 0% (magic number not found)
|
|
||||||
- Time: ~80 seconds per sample
|
|
||||||
|
|
||||||
### After Fix (Expected)
|
|
||||||
|
|
||||||
- Output: The magic number (e.g., "8930103")
|
|
||||||
- Score: ~100% (magic number found)
|
|
||||||
- Time: ~80 seconds per sample (same, as the compute is unchanged)
|
|
||||||
|
|
||||||
## Debugging Tips
|
|
||||||
|
|
||||||
### Enable Verbose Logging
|
|
||||||
|
|
||||||
```python
|
|
||||||
import logging
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Check Ring Buffer Size
|
|
||||||
|
|
||||||
In the logs, verify:
|
|
||||||
```
|
|
||||||
OffloadEngine initializing: num_layers=32, num_kv_buffers=4, max_seq_len=33280
|
|
||||||
```
|
|
||||||
|
|
||||||
The `max_seq_len` should be `32768 + 512 = 33280` (not 32768).
|
|
||||||
|
|
||||||
### Monitor GPU Memory
|
|
||||||
|
|
||||||
```bash
|
|
||||||
watch -n 1 nvidia-smi
|
|
||||||
```
|
|
||||||
|
|
||||||
With CPU offload, GPU memory for KV cache should be ~640MB (ring buffer only).
|
|
||||||
|
|
||||||
## Related Files
|
|
||||||
|
|
||||||
| File | Description |
|
|
||||||
|------|-------------|
|
|
||||||
| `nanovllm/kvcache/__init__.py` | Fix location: `max_seq_len = max_model_len + 512` |
|
|
||||||
| `nanovllm/kvcache/offload_engine.py` | Ring buffer allocation |
|
|
||||||
| `nanovllm/engine/model_runner.py` | Layer-wise offload prefill/decode |
|
|
||||||
| `nanovllm/kvcache/hybrid_manager.py` | CPU block management |
|
|
||||||
|
|
||||||
## Test Data Details
|
|
||||||
|
|
||||||
### NIAH Task Description
|
|
||||||
|
|
||||||
The NIAH (Needle-In-A-Haystack) task tests the model's ability to retrieve a specific piece of information (the "needle") from a large context (the "haystack").
|
|
||||||
|
|
||||||
- **Needle**: A magic number associated with a keyword (e.g., "worried-purse")
|
|
||||||
- **Haystack**: ~30K tokens of distractor text
|
|
||||||
- **Task**: Extract the magic number when asked
|
|
||||||
|
|
||||||
### Sample Prompt Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
|
|
||||||
|
|
||||||
A special magic number is hidden within the following text. Make sure to memorize it. I will quiz you about the number afterwards.
|
|
||||||
|
|
||||||
[... ~30K tokens of haystack text ...]
|
|
||||||
|
|
||||||
The special magic number for worried-purse is 8930103.
|
|
||||||
|
|
||||||
[... more haystack text ...]
|
|
||||||
|
|
||||||
What is the special magic number for worried-purse mentioned in the provided text?
|
|
||||||
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
|
||||||
|
|
||||||
The special magic number for worried-purse mentioned in the provided text is
|
|
||||||
```
|
|
||||||
|
|
||||||
The model should complete with: `8930103`
|
|
||||||
@@ -440,42 +440,3 @@ Required libraries:
|
|||||||
- `minference`: For MInference vertical_slash kernel
|
- `minference`: For MInference vertical_slash kernel
|
||||||
|
|
||||||
Docker image `tzj/xattn:v0.5` has all dependencies pre-installed.
|
Docker image `tzj/xattn:v0.5` has all dependencies pre-installed.
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Quest Sparse Policy (nano-vLLM)
|
|
||||||
|
|
||||||
**Files**: `nanovllm/kvcache/sparse/quest.py`, `nanovllm/kvcache/sparse/policy.py`
|
|
||||||
|
|
||||||
Quest policy is used in nano-vLLM for CPU offload mode. It selects Top-K blocks based on query-key similarity bounds using min/max key metadata.
|
|
||||||
|
|
||||||
### Scoring Mechanism
|
|
||||||
|
|
||||||
```python
|
|
||||||
score_min = torch.einsum('hd,bhd->bh', q, key_min) # [num_blocks, kv_heads]
|
|
||||||
score_max = torch.einsum('hd,bhd->bh', q, key_max) # [num_blocks, kv_heads]
|
|
||||||
scores = torch.maximum(score_min, score_max).mean(dim=-1) # [num_blocks] ← averaged!
|
|
||||||
```
|
|
||||||
|
|
||||||
### Critical Limitation - No Per-Head Scheduling
|
|
||||||
|
|
||||||
The `.mean(dim=-1)` averages scores across all heads, making a **unified** block selection for all heads:
|
|
||||||
|
|
||||||
```
|
|
||||||
Block A: head0 needs (+4), head1 doesn't (-4) → avg = 0 → NOT selected
|
|
||||||
Block B: head0 doesn't (-4), head1 needs (+4) → avg = 0 → NOT selected
|
|
||||||
Block C: both heads moderately need (+2, +2) → avg = +2 → selected
|
|
||||||
```
|
|
||||||
|
|
||||||
### Why Per-Head Scheduling is Infeasible
|
|
||||||
|
|
||||||
1. **Memory Layout**: GPU cache stores all heads together `[block_size, kv_heads, head_dim]`
|
|
||||||
2. **FlashAttention**: Requires complete heads - partial heads cause dimension mismatch
|
|
||||||
3. **Block Granularity**: If any head needs a block, the entire block (all heads) must be loaded
|
|
||||||
|
|
||||||
### Policy Types
|
|
||||||
|
|
||||||
| Policy | `supports_prefill` | `supports_decode` | Description |
|
|
||||||
|--------|-------------------|-------------------|-------------|
|
|
||||||
| `FullAttentionPolicy` | True | True | Loads all blocks (baseline) |
|
|
||||||
| `QuestPolicy` | False | True | Decode-only Top-K selection |
|
|
||||||
|
|||||||
@@ -1,386 +0,0 @@
|
|||||||
# Sparse Policy Integration with Layerwise Offload
|
|
||||||
|
|
||||||
This document describes the architecture and design of integrating sparse attention policies (MInference, Quest) with the layerwise CPU offload execution path.
|
|
||||||
|
|
||||||
## Design Goals
|
|
||||||
|
|
||||||
1. **Extend sparse policies to offload path**: GPU-only path already supports sparse policies, but layerwise offload bypasses them
|
|
||||||
2. **Maintain encapsulation**: All `copy_()` operations must be inside OffloadEngine, not exposed to model_runner
|
|
||||||
3. **Distinguish policy types**: Some policies affect attention computation (MInference), others affect KV load strategy (Quest)
|
|
||||||
4. **Extensible architecture**: Easy to add new sparse policies in the future
|
|
||||||
|
|
||||||
## Key Insight
|
|
||||||
|
|
||||||
The existing sparse policy implementation works, but the layerwise offload path bypasses it:
|
|
||||||
|
|
||||||
| Path | Attention Method | Sparse Support |
|
|
||||||
|------|------------------|----------------|
|
|
||||||
| GPU-only | `attention.py` → `sparse_prefill_attention()` | YES |
|
|
||||||
| Layerwise offload | `model_runner.py` → `flash_attn_varlen_func()` | NO (direct call) |
|
|
||||||
|
|
||||||
## Two Types of Sparse Policies
|
|
||||||
|
|
||||||
The fundamental difference between sparse policies:
|
|
||||||
|
|
||||||
| Policy | Affects Attention Computation | Affects KV Load Strategy | `select_blocks()` Behavior |
|
|
||||||
|--------|------------------------------|--------------------------|---------------------------|
|
|
||||||
| **MInference** | YES (`sparse_prefill_attention`) | NO | `return available_blocks` (all) |
|
|
||||||
| **Quest** | NO | YES | Returns Top-K subset |
|
|
||||||
|
|
||||||
- **MInference**: Only changes how attention is computed, doesn't affect external load/offload flow
|
|
||||||
- **Quest**: Selectively loads only some blocks, affects H2D transfer
|
|
||||||
|
|
||||||
## The `requires_block_selection` Interface Flag
|
|
||||||
|
|
||||||
To distinguish these policy types, we add a flag to the base class:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# nanovllm/kvcache/sparse/policy.py
|
|
||||||
class SparsePolicy(ABC):
|
|
||||||
# Existing flags
|
|
||||||
supports_prefill: bool = True
|
|
||||||
supports_decode: bool = True
|
|
||||||
|
|
||||||
# NEW: Whether this policy requires selective block loading
|
|
||||||
# If True: OffloadEngine will call select_blocks() before loading
|
|
||||||
# If False: OffloadEngine will load all blocks (select_blocks ignored)
|
|
||||||
requires_block_selection: bool = False
|
|
||||||
```
|
|
||||||
|
|
||||||
### Policy Implementations
|
|
||||||
|
|
||||||
```python
|
|
||||||
# MInference: prefill-only, no block selection
|
|
||||||
class MInferencePolicy(SparsePolicy):
|
|
||||||
supports_prefill = True
|
|
||||||
supports_decode = False
|
|
||||||
requires_block_selection = False # Only affects attention computation
|
|
||||||
|
|
||||||
# Quest: decode-only, requires block selection
|
|
||||||
class QuestPolicy(SparsePolicy):
|
|
||||||
supports_prefill = False
|
|
||||||
supports_decode = True
|
|
||||||
requires_block_selection = True # Affects KV load strategy
|
|
||||||
|
|
||||||
# Full attention: baseline
|
|
||||||
class FullAttentionPolicy(SparsePolicy):
|
|
||||||
supports_prefill = True
|
|
||||||
supports_decode = True
|
|
||||||
requires_block_selection = False # Load all blocks
|
|
||||||
```
|
|
||||||
|
|
||||||
## OffloadEngine Encapsulation
|
|
||||||
|
|
||||||
All KV cache operations are encapsulated in OffloadEngine. The model_runner never directly accesses internal storage.
|
|
||||||
|
|
||||||
### Prefill: Synchronous Offload with Hooks
|
|
||||||
|
|
||||||
```python
|
|
||||||
# nanovllm/kvcache/offload_engine.py
|
|
||||||
def offload_layer_kv_sync(
|
|
||||||
self,
|
|
||||||
layer_id: int,
|
|
||||||
k: Tensor,
|
|
||||||
v: Tensor,
|
|
||||||
cpu_block_ids: List[int],
|
|
||||||
total_tokens: int,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Synchronously offload layer KV to CPU.
|
|
||||||
Calls sparse policy hooks internally.
|
|
||||||
"""
|
|
||||||
for i, cpu_block_id in enumerate(cpu_block_ids):
|
|
||||||
start = i * self.block_size
|
|
||||||
end = min(start + self.block_size, total_tokens)
|
|
||||||
actual_size = end - start
|
|
||||||
|
|
||||||
# Hook: notify sparse policy BEFORE offload (k still on GPU)
|
|
||||||
if self.sparse_policy is not None:
|
|
||||||
self.sparse_policy.on_prefill_offload(
|
|
||||||
cpu_block_id, layer_id, k[start:end], actual_size
|
|
||||||
)
|
|
||||||
|
|
||||||
# Synchronous copy to CPU (internal)
|
|
||||||
self.k_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(k[start:end])
|
|
||||||
self.v_cache_cpu[layer_id, cpu_block_id, :actual_size].copy_(v[start:end])
|
|
||||||
```
|
|
||||||
|
|
||||||
### Decode: Policy-Driven Block Loading
|
|
||||||
|
|
||||||
```python
|
|
||||||
def load_layer_kv_to_buffer_with_policy(
|
|
||||||
self,
|
|
||||||
buffer_idx: int,
|
|
||||||
layer_id: int,
|
|
||||||
cpu_block_ids: List[int],
|
|
||||||
valid_tokens_per_block: List[int],
|
|
||||||
query: Optional[Tensor] = None,
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Load layer KV to buffer, optionally using sparse policy for block selection.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Total tokens loaded
|
|
||||||
"""
|
|
||||||
# Check if policy requires block selection
|
|
||||||
if (self.sparse_policy is not None and
|
|
||||||
self.sparse_policy.requires_block_selection and
|
|
||||||
query is not None):
|
|
||||||
# Build context
|
|
||||||
ctx = PolicyContext(
|
|
||||||
query_chunk_idx=0,
|
|
||||||
num_query_chunks=1,
|
|
||||||
layer_id=layer_id,
|
|
||||||
query=query,
|
|
||||||
is_prefill=False,
|
|
||||||
block_size=self.block_size,
|
|
||||||
)
|
|
||||||
# Select blocks using policy
|
|
||||||
selected_blocks = self.sparse_policy.select_blocks(cpu_block_ids, ctx)
|
|
||||||
|
|
||||||
# Build valid_tokens for selected blocks
|
|
||||||
block_to_valid = {bid: vt for bid, vt in zip(cpu_block_ids, valid_tokens_per_block)}
|
|
||||||
selected_valid = [block_to_valid[bid] for bid in selected_blocks]
|
|
||||||
|
|
||||||
return self._load_blocks_to_buffer(
|
|
||||||
buffer_idx, layer_id, selected_blocks, selected_valid
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Load all blocks (no selection)
|
|
||||||
return self._load_blocks_to_buffer(
|
|
||||||
buffer_idx, layer_id, cpu_block_ids, valid_tokens_per_block
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## Prefill Integration (MInference)
|
|
||||||
|
|
||||||
MInference only affects attention computation, not the load/offload flow:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# nanovllm/engine/model_runner.py - run_layerwise_offload_prefill()
|
|
||||||
def run_layerwise_offload_prefill(self, seqs):
|
|
||||||
...
|
|
||||||
for layer_id in range(num_layers):
|
|
||||||
# QKV projection + RoPE
|
|
||||||
q, k = layer.self_attn.rotary_emb(positions, q, k)
|
|
||||||
|
|
||||||
# Sparse or Full attention
|
|
||||||
if self.sparse_prefill_policy is not None:
|
|
||||||
# MInference: only changes attention computation
|
|
||||||
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
|
|
||||||
q, k, v, layer_id
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Full attention using FlashAttention
|
|
||||||
attn_output = flash_attn_varlen_func(q, k, v, ...)
|
|
||||||
|
|
||||||
# MLP
|
|
||||||
...
|
|
||||||
|
|
||||||
# Offload ALL KV (MInference doesn't affect this)
|
|
||||||
offload_engine.offload_layer_kv_sync(layer_id, k, v, cpu_block_ids, total_tokens)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Execution Flow Diagram
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────────────────────────────────────────────────────────┐
|
|
||||||
│ Layerwise Offload Prefill │
|
|
||||||
│ with MInference │
|
|
||||||
└─────────────────────────────────────────────────────────────────┘
|
|
||||||
|
|
||||||
For each layer:
|
|
||||||
┌──────────────┐ ┌──────────────┐ ┌────────────────────────┐
|
|
||||||
│ QKV Proj │───▶│ RoPE │───▶│ sparse_prefill_attn() │
|
|
||||||
│ │ │ │ │ (MInference pattern) │
|
|
||||||
└──────────────┘ └──────────────┘ └───────────┬────────────┘
|
|
||||||
│
|
|
||||||
┌──────────────┐ ┌───────────▼────────────┐
|
|
||||||
│ MLP │◀───│ O Projection │
|
|
||||||
│ │ │ │
|
|
||||||
└──────┬───────┘ └────────────────────────┘
|
|
||||||
│
|
|
||||||
┌──────▼───────┐
|
|
||||||
│ offload_ │ K, V still on GPU
|
|
||||||
│ layer_kv_ │───▶ Copy to CPU
|
|
||||||
│ sync() │ (all blocks)
|
|
||||||
└──────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
## Decode Integration (Quest - Infrastructure Ready)
|
|
||||||
|
|
||||||
Quest affects block load strategy. The infrastructure is ready, full integration deferred.
|
|
||||||
|
|
||||||
```python
|
|
||||||
# nanovllm/engine/model_runner.py - run_layerwise_offload_decode()
|
|
||||||
def run_layerwise_offload_decode(self, seqs):
|
|
||||||
...
|
|
||||||
# Preload first N layers (no query available, full load)
|
|
||||||
for i in range(num_preload):
|
|
||||||
loaded_tokens[i] = offload_engine.load_layer_kv_to_buffer(
|
|
||||||
i, i, cpu_block_table, valid_tokens_per_block
|
|
||||||
)
|
|
||||||
|
|
||||||
for layer_id in range(num_layers):
|
|
||||||
current_buffer = layer_id % num_buffers
|
|
||||||
|
|
||||||
# Wait for buffer load
|
|
||||||
offload_engine.wait_buffer_load(current_buffer)
|
|
||||||
|
|
||||||
# QKV projection
|
|
||||||
q, k_new, v_new = ...
|
|
||||||
|
|
||||||
# Get loaded KV from ring buffer
|
|
||||||
k_prefill, v_prefill = offload_engine.get_buffer_kv(
|
|
||||||
current_buffer, loaded_tokens[current_buffer]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Attention
|
|
||||||
...
|
|
||||||
|
|
||||||
# Mark buffer done
|
|
||||||
offload_engine.record_buffer_compute_done(current_buffer)
|
|
||||||
|
|
||||||
# Load next layer
|
|
||||||
# Future: use load_layer_kv_to_buffer_with_policy(query=q) for Quest
|
|
||||||
next_layer = layer_id + num_buffers
|
|
||||||
if next_layer < num_layers:
|
|
||||||
loaded_tokens[current_buffer] = offload_engine.load_layer_kv_to_buffer(
|
|
||||||
current_buffer, next_layer, cpu_block_table, valid_tokens_per_block
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Quest Integration (Future Work)
|
|
||||||
|
|
||||||
When Quest is fully integrated:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Load next layer with Quest block selection
|
|
||||||
if next_layer < num_layers:
|
|
||||||
loaded_tokens[current_buffer] = offload_engine.load_layer_kv_to_buffer_with_policy(
|
|
||||||
current_buffer, next_layer, cpu_block_table, valid_tokens_per_block,
|
|
||||||
query=q # Pass query for block selection
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Challenge**: First N layers are preloaded before query is available, so they must use full load.
|
|
||||||
|
|
||||||
## Configuration
|
|
||||||
|
|
||||||
### Enabling Sparse Policy
|
|
||||||
|
|
||||||
```python
|
|
||||||
from nanovllm import LLM
|
|
||||||
from nanovllm.config import SparsePolicyType
|
|
||||||
|
|
||||||
# GPU-only with MInference
|
|
||||||
llm = LLM(
|
|
||||||
model_path,
|
|
||||||
sparse_policy=SparsePolicyType.MINFERENCE,
|
|
||||||
minference_adaptive_budget=0.3, # 30% of seq_len
|
|
||||||
)
|
|
||||||
|
|
||||||
# Offload with MInference
|
|
||||||
llm = LLM(
|
|
||||||
model_path,
|
|
||||||
enable_cpu_offload=True,
|
|
||||||
num_gpu_blocks=2,
|
|
||||||
sparse_policy=SparsePolicyType.MINFERENCE,
|
|
||||||
minference_adaptive_budget=0.3,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### MInference Parameters
|
|
||||||
|
|
||||||
| Parameter | Default | Description |
|
|
||||||
|-----------|---------|-------------|
|
|
||||||
| `minference_adaptive_budget` | 0.3 | Budget as fraction of seq_len (0.3 = 30%) |
|
|
||||||
| `minference_vertical_size` | 1000 | Fixed vertical size (when budget=None) |
|
|
||||||
| `minference_slash_size` | 6096 | Fixed slash size (when budget=None) |
|
|
||||||
| `minference_num_sink_tokens` | 30 | Always-kept initial tokens |
|
|
||||||
| `minference_num_recent_diags` | 100 | Always-kept recent diagonals |
|
|
||||||
|
|
||||||
### Quest Parameters (for future decode integration)
|
|
||||||
|
|
||||||
| Parameter | Default | Description |
|
|
||||||
|-----------|---------|-------------|
|
|
||||||
| `sparse_topk_blocks` | 8 | Top-K blocks to load |
|
|
||||||
| `sparse_threshold_blocks` | 4 | Apply sparse only when blocks > threshold |
|
|
||||||
|
|
||||||
## Sparse Policy Hooks
|
|
||||||
|
|
||||||
Sparse policies can implement hooks for metadata collection:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class SparsePolicy(ABC):
|
|
||||||
def on_prefill_offload(
|
|
||||||
self,
|
|
||||||
block_id: int,
|
|
||||||
layer_id: int,
|
|
||||||
key: torch.Tensor,
|
|
||||||
valid_tokens: int,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Hook called during prefill offload BEFORE KV is copied to CPU.
|
|
||||||
Key tensor is still on GPU - can compute metadata efficiently.
|
|
||||||
|
|
||||||
Used by Quest to compute min/max key statistics for block selection.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_decode_offload(
|
|
||||||
self,
|
|
||||||
block_id: int,
|
|
||||||
keys: torch.Tensor, # [num_layers, block_size, kv_heads, head_dim]
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Hook called when decode buffer is offloaded to CPU.
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
## File Changes Summary
|
|
||||||
|
|
||||||
| File | Changes |
|
|
||||||
|------|---------|
|
|
||||||
| `nanovllm/kvcache/sparse/policy.py` | Add `requires_block_selection` attribute |
|
|
||||||
| `nanovllm/kvcache/sparse/minference.py` | Set `requires_block_selection = False` |
|
|
||||||
| `nanovllm/kvcache/sparse/quest.py` | Set `requires_block_selection = True` |
|
|
||||||
| `nanovllm/kvcache/sparse/full_policy.py` | Set `requires_block_selection = False` |
|
|
||||||
| `nanovllm/kvcache/offload_engine.py` | Add `offload_layer_kv_sync()`, sparse hooks |
|
|
||||||
| `nanovllm/engine/model_runner.py` | Integrate sparse policies in offload paths |
|
|
||||||
|
|
||||||
## Key Design Principles
|
|
||||||
|
|
||||||
1. **Encapsulation**: All `copy_()` operations inside OffloadEngine
|
|
||||||
2. **Interface Flag**: `requires_block_selection` declares policy type
|
|
||||||
3. **Separation of Concerns**:
|
|
||||||
- MInference: only `sparse_prefill_attention()` (compute-level)
|
|
||||||
- Quest: `select_blocks()` + hooks (load-level)
|
|
||||||
4. **Hooks Inside Engine**: Policy hooks called within OffloadEngine methods
|
|
||||||
|
|
||||||
## Test Results
|
|
||||||
|
|
||||||
Verified on Qwen3-4B-Instruct-2507 with 32K input:
|
|
||||||
|
|
||||||
```
|
|
||||||
# GPU-only + MInference
|
|
||||||
test_needle.py --model Qwen3-4B --input-len 32768 --enable-minference
|
|
||||||
- Prefill: 3383 tok/s
|
|
||||||
- Output: "7492<|im_end|>"
|
|
||||||
- Result: PASSED
|
|
||||||
|
|
||||||
# Offload + MInference
|
|
||||||
test_needle.py --model Qwen3-4B --input-len 32768 --enable-offload --enable-minference
|
|
||||||
- Prefill: 5373 tok/s
|
|
||||||
- Output: "7492<|im_end|>"
|
|
||||||
- Result: PASSED
|
|
||||||
```
|
|
||||||
|
|
||||||
Both configurations produce identical outputs, confirming correctness.
|
|
||||||
|
|
||||||
## Related Documents
|
|
||||||
|
|
||||||
- [`sparse_attention_guide.md`](sparse_attention_guide.md): Algorithm details for sparse methods
|
|
||||||
- [`architecture_guide.md`](architecture_guide.md): Overall system architecture
|
|
||||||
- [`gpu_only_performance_issue.md`](gpu_only_performance_issue.md): Why offload is faster than GPU-only
|
|
||||||
@@ -1,367 +0,0 @@
|
|||||||
# Sparse Prefill Attention Integration Plan
|
|
||||||
|
|
||||||
## Executive Summary
|
|
||||||
|
|
||||||
本文档整合了 int-minference-1/2/3 三个分支的分析,提出统一的三种稀疏注意力策略(MInference、XAttention、FlexPrefill)集成方案。
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Part 1: 现状分析
|
|
||||||
|
|
||||||
### 1.1 x-attention 仓库策略对比
|
|
||||||
|
|
||||||
| 策略 | Pattern 类型 | 估计方法 | Kernel Backend |
|
|
||||||
|------|-------------|---------|----------------|
|
|
||||||
| **MInference** | Vertical + Slash | Last-64-Q attention → 列/对角线求和 | `vertical_slash_sparse_attention` (minference lib) |
|
|
||||||
| **XAttention** | Block Mask | Stride-based Q/K 下采样 → block 分数 | `block_sparse_attn_func` (MIT-HAN-LAB) |
|
|
||||||
| **FlexPrefill** | Adaptive V+S | Last-block attention + JS 散度自适应 | `triton_block_wise_attention` (custom triton) |
|
|
||||||
|
|
||||||
### 1.2 关键发现:两种 Kernel 接口
|
|
||||||
|
|
||||||
**接口 A: Index-Based (minference)**
|
|
||||||
```python
|
|
||||||
# MInference 使用 vertical+slash indices
|
|
||||||
vertical_indices = [heads, vertical_size] # 重要 K 列位置
|
|
||||||
slash_indices = [heads, slash_size] # 对角线偏移
|
|
||||||
output = vertical_slash_sparse_attention(q, k, v, vertical_indices, slash_indices)
|
|
||||||
```
|
|
||||||
|
|
||||||
**接口 B: Block Mask-Based (block_sparse_attn)**
|
|
||||||
```python
|
|
||||||
# XAttention/FlexPrefill 使用 boolean block mask
|
|
||||||
block_mask = torch.bool[batch, heads, q_blocks, k_blocks] # True = 计算
|
|
||||||
output = block_sparse_attn_func(q, k, v, block_mask, ...)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 1.3 当前 nanovllm MInference 实现
|
|
||||||
|
|
||||||
**文件**: `nanovllm/kvcache/sparse/minference.py`
|
|
||||||
|
|
||||||
**已实现功能**:
|
|
||||||
- `estimate_pattern()`: 使用 last-64-Q 估计 vertical+slash pattern
|
|
||||||
- `sparse_prefill_attention()`: 调用 minference kernel 执行稀疏注意力
|
|
||||||
- 支持 GQA(通过 K/V repeat_interleave)
|
|
||||||
- 支持 adaptive_budget 自适应预算
|
|
||||||
|
|
||||||
**问题**:
|
|
||||||
1. 与 XAttention/FlexPrefill 使用不同 kernel,无法统一接口
|
|
||||||
2. `sparse_prefill_attention()` 将估计和执行耦合在一起
|
|
||||||
3. 没有 BlockMask 中间表示,难以复用
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Part 2: 架构设计
|
|
||||||
|
|
||||||
### 2.1 设计原则
|
|
||||||
|
|
||||||
1. **向后兼容**: 保持现有 `SparsePolicy` 接口不变
|
|
||||||
2. **渐进式重构**: 添加新功能而非替换
|
|
||||||
3. **统一中间表示**: 新策略使用 `BlockMask` 作为可选中间表示
|
|
||||||
4. **可插拔 Kernel**: 支持多种 attention kernel backend
|
|
||||||
|
|
||||||
### 2.2 架构图
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────────────────────────────────────────────────────────────────────┐
|
|
||||||
│ Unified Sparse Prefill Framework │
|
|
||||||
├──────────────────────────────────────────────────────────────────────────────┤
|
|
||||||
│ │
|
|
||||||
│ ┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐ │
|
|
||||||
│ │ MInference │ │ XAttention │ │ FlexPrefill │ Strategies │
|
|
||||||
│ │ Policy │ │ Policy │ │ Policy │ │
|
|
||||||
│ └────────┬────────┘ └────────┬────────┘ └────────┬────────┘ │
|
|
||||||
│ │ │ │ │
|
|
||||||
│ │ (indices) │ (BlockMask) │ (BlockMask) │
|
|
||||||
│ │ │ │ │
|
|
||||||
│ ▼ └────────┬───────────┘ │
|
|
||||||
│ ┌─────────────────┐ ▼ │
|
|
||||||
│ │ minference │ ┌─────────────────────────────────────────────────────┐│
|
|
||||||
│ │ kernel │ │ BlockMask Container ││
|
|
||||||
│ └────────┬────────┘ │ [batch, num_heads, q_blocks, k_blocks] - boolean ││
|
|
||||||
│ │ └─────────────────────────────────────────────────────┘│
|
|
||||||
│ │ │ │
|
|
||||||
│ │ ▼ │
|
|
||||||
│ │ ┌─────────────────────────────────────────────────────┐│
|
|
||||||
│ │ │ block_sparse_attn_func ││
|
|
||||||
│ │ │ (MIT-HAN-LAB kernel) ││
|
|
||||||
│ │ └─────────────────────────────────────────────────────┘│
|
|
||||||
│ │ │ │
|
|
||||||
│ └──────────────────────────────┼────────────────────────────────── │
|
|
||||||
│ ▼ │
|
|
||||||
│ ┌─────────────────────────────────────────────────────────────────────────┐ │
|
|
||||||
│ │ Attention Output │ │
|
|
||||||
│ │ [seq_len, num_heads, head_dim] │ │
|
|
||||||
│ └─────────────────────────────────────────────────────────────────────────┘ │
|
|
||||||
│ │
|
|
||||||
└──────────────────────────────────────────────────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2.3 新增类设计
|
|
||||||
|
|
||||||
```python
|
|
||||||
# nanovllm/kvcache/sparse/block_mask.py
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BlockMask:
|
|
||||||
"""Block-level attention mask container."""
|
|
||||||
mask: torch.Tensor # [batch, heads, q_blocks, k_blocks]
|
|
||||||
block_size: int
|
|
||||||
seq_len: int
|
|
||||||
num_q_blocks: int
|
|
||||||
num_k_blocks: int
|
|
||||||
|
|
||||||
def sparsity_ratio(self) -> float:
|
|
||||||
"""Fraction of blocks masked out."""
|
|
||||||
return 1.0 - self.mask.float().mean().item()
|
|
||||||
|
|
||||||
def to_flat_indices(self, head_idx: int) -> torch.Tensor:
|
|
||||||
"""Convert to flattened block indices for a given head."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_vertical_slash(
|
|
||||||
cls,
|
|
||||||
vertical_idx: torch.Tensor,
|
|
||||||
slash_idx: torch.Tensor,
|
|
||||||
seq_len: int,
|
|
||||||
block_size: int,
|
|
||||||
) -> "BlockMask":
|
|
||||||
"""Convert MInference-style indices to block mask."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def apply_causal(self) -> "BlockMask":
|
|
||||||
"""Apply causal constraint (lower triangular)."""
|
|
||||||
pass
|
|
||||||
```
|
|
||||||
|
|
||||||
```python
|
|
||||||
# nanovllm/kvcache/sparse/kernels/block_sparse.py
|
|
||||||
|
|
||||||
def block_sparse_attention(
|
|
||||||
q: torch.Tensor, # [seq_len, num_heads, head_dim]
|
|
||||||
k: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
|
|
||||||
v: torch.Tensor, # [seq_len, num_kv_heads, head_dim]
|
|
||||||
block_mask: BlockMask,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Execute block sparse attention using MIT-HAN-LAB kernel.
|
|
||||||
|
|
||||||
Handles:
|
|
||||||
- GQA expansion (K/V heads < Q heads)
|
|
||||||
- Tensor format conversion
|
|
||||||
- Causal masking
|
|
||||||
"""
|
|
||||||
from block_sparse_attn import block_sparse_attn_func
|
|
||||||
# ... implementation
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Part 3: 实现计划
|
|
||||||
|
|
||||||
### Phase 1: 基础设施 (新增文件)
|
|
||||||
|
|
||||||
**目标**: 添加 BlockMask 和 block_sparse_attn 封装
|
|
||||||
|
|
||||||
**文件**:
|
|
||||||
- `nanovllm/kvcache/sparse/block_mask.py` (NEW)
|
|
||||||
- `nanovllm/kvcache/sparse/kernels/__init__.py` (NEW)
|
|
||||||
- `nanovllm/kvcache/sparse/kernels/block_sparse.py` (NEW)
|
|
||||||
|
|
||||||
**任务**:
|
|
||||||
1. 实现 `BlockMask` 数据类
|
|
||||||
2. 实现 `block_sparse_attention()` 封装函数
|
|
||||||
3. 处理 GQA 和 tensor 格式转换
|
|
||||||
4. 测试:使用全 True 的 block mask 验证输出正确
|
|
||||||
|
|
||||||
### Phase 2: XAttention 实现
|
|
||||||
|
|
||||||
**目标**: 移植 x-attention 的 XAttention 策略
|
|
||||||
|
|
||||||
**文件**:
|
|
||||||
- `nanovllm/kvcache/sparse/xattention.py` (NEW)
|
|
||||||
- `nanovllm/config.py` (添加 XATTENTION 枚举)
|
|
||||||
- `nanovllm/kvcache/sparse/__init__.py` (更新工厂函数)
|
|
||||||
|
|
||||||
**关键函数移植**:
|
|
||||||
```python
|
|
||||||
# From x-attention/xattn/src/Xattention.py
|
|
||||||
def xattn_estimate(q, k, block_size, stride, threshold, ...):
|
|
||||||
# 1. Stride-based Q/K downsampling
|
|
||||||
reshaped_k = cat([k[:, :, i::stride, :] for i in range(stride)], dim=-1)
|
|
||||||
reshaped_q = cat([q[:, :, stride-1-i::stride, :] for i in range(stride)], dim=-1)
|
|
||||||
|
|
||||||
# 2. Block-level attention scores
|
|
||||||
attn_weights = matmul(reshaped_q, reshaped_k.T) / sqrt(d) / stride
|
|
||||||
|
|
||||||
# 3. Threshold selection
|
|
||||||
block_mask = find_blocks_chunked(attn_sum, threshold)
|
|
||||||
return block_mask
|
|
||||||
```
|
|
||||||
|
|
||||||
**配置参数**:
|
|
||||||
```python
|
|
||||||
xattention_stride: int = 16 # Q/K 下采样步长
|
|
||||||
xattention_threshold: float = 0.9 # 累积分数阈值
|
|
||||||
xattention_block_size: int = 128 # Block 大小
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试**: `python tests/test_needle.py --input-len 32768 --enable-xattention`
|
|
||||||
|
|
||||||
### Phase 3: FlexPrefill 实现
|
|
||||||
|
|
||||||
**目标**: 移植 x-attention 的 FlexPrefill 策略
|
|
||||||
|
|
||||||
**文件**:
|
|
||||||
- `nanovllm/kvcache/sparse/flexprefill.py` (NEW)
|
|
||||||
- `nanovllm/config.py` (添加 FLEXPREFILL 枚举)
|
|
||||||
|
|
||||||
**关键函数移植**:
|
|
||||||
```python
|
|
||||||
# From x-attention/xattn/src/Flexprefill.py
|
|
||||||
def get_active_blocks(q, k, gamma, tau, block_size, ...):
|
|
||||||
# 1. Last-block attention analysis
|
|
||||||
last_q = q[:, -block_size:, :, :]
|
|
||||||
qk = einsum('bihd,bjhd->bhij', last_q, k)
|
|
||||||
|
|
||||||
# 2. Vertical + slash pattern detection
|
|
||||||
vertical = qk.mean(-2) # Column importance
|
|
||||||
slash = sum_all_diagonal_matrix(qk) # Diagonal importance
|
|
||||||
|
|
||||||
# 3. JS divergence for adaptive budget
|
|
||||||
kl_div = js_divergence(avg_qk, vertical_pooled)
|
|
||||||
is_sparse_head = kl_div > tau
|
|
||||||
budget = gamma if is_sparse_head else 1.0
|
|
||||||
|
|
||||||
# 4. Select blocks
|
|
||||||
block_idx = transform_vertical_slash_idx(...)
|
|
||||||
return block_mask
|
|
||||||
```
|
|
||||||
|
|
||||||
**配置参数**:
|
|
||||||
```python
|
|
||||||
flexprefill_gamma: float = 0.9 # 基础覆盖率
|
|
||||||
flexprefill_tau: float = 0.1 # JS 散度阈值
|
|
||||||
flexprefill_min_budget: int = 128 # 最小 token 预算
|
|
||||||
flexprefill_block_size: int = 128 # Block 大小
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试**: `python tests/test_needle.py --input-len 32768 --enable-flexprefill`
|
|
||||||
|
|
||||||
### Phase 4: MInference 可选重构
|
|
||||||
|
|
||||||
**目标**: (可选) 让 MInference 也可以使用 block_sparse_attn
|
|
||||||
|
|
||||||
**修改文件**:
|
|
||||||
- `nanovllm/kvcache/sparse/minference.py`
|
|
||||||
|
|
||||||
**新增方法**:
|
|
||||||
```python
|
|
||||||
class MInferencePolicy(SparsePolicy):
|
|
||||||
def __init__(self, ..., use_block_sparse: bool = False):
|
|
||||||
self.use_block_sparse = use_block_sparse
|
|
||||||
|
|
||||||
def estimate_block_mask(self, q, k, layer_id) -> BlockMask:
|
|
||||||
"""Convert vertical+slash indices to BlockMask."""
|
|
||||||
vertical_idx, slash_idx = self.estimate_pattern(q, k, layer_id)
|
|
||||||
return BlockMask.from_vertical_slash(vertical_idx, slash_idx, ...)
|
|
||||||
|
|
||||||
def sparse_prefill_attention(self, q, k, v, layer_id):
|
|
||||||
if self.use_block_sparse:
|
|
||||||
block_mask = self.estimate_block_mask(q, k, layer_id)
|
|
||||||
return block_sparse_attention(q, k, v, block_mask)
|
|
||||||
else:
|
|
||||||
# 使用原有 minference kernel
|
|
||||||
return self._minference_kernel_attention(q, k, v, layer_id)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Phase 5: 集成和测试
|
|
||||||
|
|
||||||
**任务**:
|
|
||||||
1. 更新 `__init__.py` 工厂函数支持所有策略
|
|
||||||
2. 更新 Config 添加所有配置参数
|
|
||||||
3. 添加性能基准测试脚本
|
|
||||||
4. 更新文档
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Part 4: 依赖管理
|
|
||||||
|
|
||||||
### 必需依赖
|
|
||||||
|
|
||||||
```
|
|
||||||
# requirements.txt 新增
|
|
||||||
block-sparse-attn # MIT-HAN-LAB block sparse kernel
|
|
||||||
triton>=2.0 # FlexPrefill Triton kernels
|
|
||||||
```
|
|
||||||
|
|
||||||
### 安装说明
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# block_sparse_attn from MIT-HAN-LAB
|
|
||||||
pip install git+https://github.com/mit-han-lab/Block-Sparse-Attention.git
|
|
||||||
|
|
||||||
# 或从本地安装(如果有)
|
|
||||||
cd /home/zijie/Code/x-attention/Block-Sparse-Attention
|
|
||||||
pip install -e .
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Part 5: 配置参数汇总
|
|
||||||
|
|
||||||
### SparsePolicyType 枚举
|
|
||||||
|
|
||||||
```python
|
|
||||||
class SparsePolicyType(str, Enum):
|
|
||||||
FULL = "full" # 全注意力(无稀疏)
|
|
||||||
QUEST = "quest" # Decode-only Top-K
|
|
||||||
MINFERENCE = "minference" # Prefill vertical+slash
|
|
||||||
XATTENTION = "xattention" # Prefill stride-based block
|
|
||||||
FLEXPREFILL = "flexprefill" # Prefill adaptive JS-divergence
|
|
||||||
```
|
|
||||||
|
|
||||||
### 策略参数对照表
|
|
||||||
|
|
||||||
| 策略 | 参数 | 默认值 | 说明 |
|
|
||||||
|------|-----|--------|------|
|
|
||||||
| MInference | `adaptive_budget` | 0.3 | 预算占 seq_len 比例 |
|
|
||||||
| MInference | `vertical_size` | 1000 | 固定 vertical 大小 |
|
|
||||||
| MInference | `slash_size` | 6096 | 固定 slash 大小 |
|
|
||||||
| XAttention | `stride` | 16 | Q/K 下采样步长 |
|
|
||||||
| XAttention | `threshold` | 0.9 | 累积分数阈值 |
|
|
||||||
| XAttention | `block_size` | 128 | Block 大小 |
|
|
||||||
| FlexPrefill | `gamma` | 0.9 | 基础覆盖率 |
|
|
||||||
| FlexPrefill | `tau` | 0.1 | JS 散度阈值 |
|
|
||||||
| FlexPrefill | `min_budget` | 128 | 最小 token 预算 |
|
|
||||||
| FlexPrefill | `block_size` | 128 | Block 大小 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Part 6: 成功标准
|
|
||||||
|
|
||||||
1. **正确性**: 所有三种策略通过 32K+ needle-in-haystack 测试
|
|
||||||
2. **性能**: 稀疏 prefill 比全注意力快 (>1.5x speedup at 64K)
|
|
||||||
3. **统一接口**: XAttention/FlexPrefill 使用 BlockMask + block_sparse_attn
|
|
||||||
4. **向后兼容**: 现有 MInference 配置继续工作
|
|
||||||
5. **可配置**: 所有策略参数可通过 LLM 配置设置
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Part 7: 风险评估
|
|
||||||
|
|
||||||
| 风险 | 影响 | 可能性 | 缓解措施 |
|
|
||||||
|------|-----|--------|---------|
|
|
||||||
| block_sparse_attn 硬件兼容性 | 高 | 中 | 测试目标硬件,fallback 到 flash_attn |
|
|
||||||
| MInference → block mask 精度损失 | 中 | 低 | 对比测试输出差异 |
|
|
||||||
| Triton kernel 移植问题 | 中 | 中 | 使用非 Triton fallback |
|
|
||||||
| 内存开销增加 | 低 | 低 | block_size=128 → 1KB/head for 128K |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## References
|
|
||||||
|
|
||||||
- x-attention repo: `/home/zijie/Code/x-attention`
|
|
||||||
- MIT-HAN-LAB Block-Sparse-Attention: `https://github.com/mit-han-lab/Block-Sparse-Attention`
|
|
||||||
- MInference paper: https://arxiv.org/abs/2407.02490
|
|
||||||
- Current nanovllm sparse implementation: `nanovllm/kvcache/sparse/`
|
|
||||||
@@ -1,279 +0,0 @@
|
|||||||
# Transformers 低版本兼容性问题
|
|
||||||
|
|
||||||
## 概述
|
|
||||||
|
|
||||||
本文档详细记录了 nano-vllm 在低版本 transformers(< 4.51.0)环境下的兼容性问题。这些问题源于 nano-vllm 使用了 transformers 4.51.0 才引入的 `Qwen3Config` 类。
|
|
||||||
|
|
||||||
## 问题背景
|
|
||||||
|
|
||||||
### 测试环境
|
|
||||||
|
|
||||||
| 环境 | 版本 | 说明 |
|
|
||||||
|------|------|------|
|
|
||||||
| Docker 镜像 | `tzj/ruler:v0.3` | NVIDIA PyTorch 24.08 容器 |
|
|
||||||
| transformers | 4.45.2 | 系统预装版本 |
|
|
||||||
| Python | 3.10.12 | 系统版本 |
|
|
||||||
| PyTorch | 2.5.0a0+872d972 | CUDA 12.6 |
|
|
||||||
|
|
||||||
### 冲突场景
|
|
||||||
|
|
||||||
在 RULER benchmark 测试环境中,NeMo 框架依赖 transformers 4.45.2 和特定版本的 `huggingface_hub`。升级 transformers 到 4.51.0+ 会导致:
|
|
||||||
|
|
||||||
```
|
|
||||||
ImportError: cannot import name 'ModelFilter' from 'huggingface_hub'
|
|
||||||
```
|
|
||||||
|
|
||||||
因此需要 nano-vllm 适配低版本 transformers,以便在同一环境中运行。
|
|
||||||
|
|
||||||
## 详细问题分析
|
|
||||||
|
|
||||||
### 1. 核心问题:Qwen3Config 不存在
|
|
||||||
|
|
||||||
**错误信息**:
|
|
||||||
```python
|
|
||||||
ImportError: cannot import name 'Qwen3Config' from 'transformers'
|
|
||||||
(/usr/local/lib/python3.10/dist-packages/transformers/__init__.py)
|
|
||||||
```
|
|
||||||
|
|
||||||
**问题根源**:
|
|
||||||
- `Qwen3Config` 是在 transformers **4.51.0** 版本中首次引入
|
|
||||||
- transformers 4.45.2 只包含 `Qwen2` 系列模型
|
|
||||||
|
|
||||||
**受影响版本**:
|
|
||||||
| transformers 版本 | Qwen3 支持 | 可用 Qwen 模型 |
|
|
||||||
|------------------|-----------|---------------|
|
|
||||||
| < 4.51.0 | 不支持 | qwen2, qwen2_audio, qwen2_moe, qwen2_vl |
|
|
||||||
| >= 4.51.0 | 支持 | qwen2 系列 + qwen3, qwen3_moe |
|
|
||||||
|
|
||||||
### 2. 影响范围
|
|
||||||
|
|
||||||
#### 2.1 直接影响的文件
|
|
||||||
|
|
||||||
| 文件路径 | 问题代码 | 影响 |
|
|
||||||
|---------|---------|------|
|
|
||||||
| `nanovllm/models/qwen3.py:4` | `from transformers import Qwen3Config` | 直接导入失败 |
|
|
||||||
| `nanovllm/models/__init__.py:6` | `from nanovllm.models import qwen3` | 触发 qwen3 导入 |
|
|
||||||
|
|
||||||
#### 2.2 级联影响
|
|
||||||
|
|
||||||
由于 `nanovllm/models/__init__.py` 无条件导入了 `qwen3` 模块,会导致以下级联失败:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# 这些导入都会失败
|
|
||||||
from nanovllm.models import llama # FAILED
|
|
||||||
from nanovllm.models import get_model_class # FAILED
|
|
||||||
import nanovllm # FAILED
|
|
||||||
```
|
|
||||||
|
|
||||||
**测试验证**:
|
|
||||||
```python
|
|
||||||
# transformers 4.45.2 环境
|
|
||||||
|
|
||||||
>>> from nanovllm.models.registry import register_model
|
|
||||||
SUCCESS # registry 本身可以导入
|
|
||||||
|
|
||||||
>>> from nanovllm.config import Config
|
|
||||||
SUCCESS # config 不依赖 Qwen3Config
|
|
||||||
|
|
||||||
>>> from nanovllm.models import llama
|
|
||||||
FAILED: cannot import name 'Qwen3Config' from 'transformers'
|
|
||||||
# 因为 models/__init__.py 先导入了 qwen3
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Qwen3Config 使用位置
|
|
||||||
|
|
||||||
在 `nanovllm/models/qwen3.py` 中的使用:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Line 4
|
|
||||||
from transformers import Qwen3Config
|
|
||||||
|
|
||||||
# Line 128-129: 类型注解
|
|
||||||
class Qwen3DecoderLayer(nn.Module):
|
|
||||||
def __init__(self, config: Qwen3Config) -> None:
|
|
||||||
...
|
|
||||||
|
|
||||||
# Line 170-171: 类型注解
|
|
||||||
class Qwen3Model(nn.Module):
|
|
||||||
def __init__(self, config: Qwen3Config) -> None:
|
|
||||||
...
|
|
||||||
|
|
||||||
# Line 200-203: 类型注解
|
|
||||||
class Qwen3ForCausalLM(nn.Module):
|
|
||||||
def __init__(self, config: Qwen3Config) -> None:
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. Qwen3Config 属性使用
|
|
||||||
|
|
||||||
代码中使用了以下 `Qwen3Config` 属性:
|
|
||||||
|
|
||||||
| 属性 | 位置 | 用途 |
|
|
||||||
|------|------|------|
|
|
||||||
| `hidden_size` | Line 131, 147, 173 | 隐藏层维度 |
|
|
||||||
| `num_attention_heads` | Line 132 | 注意力头数 |
|
|
||||||
| `num_key_value_heads` | Line 133 | KV 头数 |
|
|
||||||
| `max_position_embeddings` | Line 134 | 最大位置编码 |
|
|
||||||
| `rms_norm_eps` | Line 135, 147, 148, 175 | RMSNorm epsilon |
|
|
||||||
| `attention_bias` | Line 136 (getattr) | 是否使用注意力偏置 |
|
|
||||||
| `head_dim` | Line 137 (getattr) | 注意力头维度 |
|
|
||||||
| `rope_theta` | Line 138 (getattr) | RoPE base |
|
|
||||||
| `rope_scaling` | Line 139 (getattr) | RoPE scaling 配置 |
|
|
||||||
| `intermediate_size` | Line 144 | FFN 中间层维度 |
|
|
||||||
| `hidden_act` | Line 145 | 激活函数类型 |
|
|
||||||
| `vocab_size` | Line 173, 206 | 词表大小 |
|
|
||||||
| `num_hidden_layers` | Line 174 | Transformer 层数 |
|
|
||||||
| `tie_word_embeddings` | Line 207 | 是否共享词嵌入 |
|
|
||||||
|
|
||||||
## 解决方案建议
|
|
||||||
|
|
||||||
### 方案 1: 条件导入(推荐)
|
|
||||||
|
|
||||||
修改 `nanovllm/models/__init__.py`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
"""Model registry and model implementations."""
|
|
||||||
|
|
||||||
from nanovllm.models.registry import register_model, get_model_class, MODEL_REGISTRY
|
|
||||||
|
|
||||||
# Import models to trigger registration
|
|
||||||
# Llama is always available
|
|
||||||
from nanovllm.models import llama
|
|
||||||
|
|
||||||
# Qwen3 requires transformers >= 4.51.0
|
|
||||||
try:
|
|
||||||
from nanovllm.models import qwen3
|
|
||||||
except ImportError:
|
|
||||||
import warnings
|
|
||||||
warnings.warn(
|
|
||||||
"Qwen3 models require transformers >= 4.51.0. "
|
|
||||||
"Install with: pip install 'transformers>=4.51.0'"
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = ["register_model", "get_model_class", "MODEL_REGISTRY"]
|
|
||||||
```
|
|
||||||
|
|
||||||
修改 `nanovllm/models/qwen3.py`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import torch
|
|
||||||
from torch import nn
|
|
||||||
import torch.distributed as dist
|
|
||||||
|
|
||||||
# Conditional import for Qwen3Config
|
|
||||||
try:
|
|
||||||
from transformers import Qwen3Config
|
|
||||||
except ImportError:
|
|
||||||
# Create a placeholder for type hints when Qwen3Config is not available
|
|
||||||
Qwen3Config = None
|
|
||||||
raise ImportError(
|
|
||||||
"Qwen3Config requires transformers >= 4.51.0. "
|
|
||||||
"Current version does not support Qwen3 models."
|
|
||||||
)
|
|
||||||
|
|
||||||
# ... rest of the code
|
|
||||||
```
|
|
||||||
|
|
||||||
### 方案 2: 使用 AutoConfig(兼容性更好)
|
|
||||||
|
|
||||||
修改 `nanovllm/models/qwen3.py` 以使用 `AutoConfig` 而非具体的 `Qwen3Config`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
# Only import Qwen3Config for type checking
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from transformers import Qwen3Config
|
|
||||||
|
|
||||||
# Runtime: use duck typing
|
|
||||||
class Qwen3DecoderLayer(nn.Module):
|
|
||||||
def __init__(self, config: Any) -> None: # Accept any config-like object
|
|
||||||
super().__init__()
|
|
||||||
# Access attributes via getattr for safety
|
|
||||||
self.self_attn = Qwen3Attention(
|
|
||||||
hidden_size=config.hidden_size,
|
|
||||||
num_heads=config.num_attention_heads,
|
|
||||||
num_kv_heads=config.num_key_value_heads,
|
|
||||||
max_position=config.max_position_embeddings,
|
|
||||||
rms_norm_eps=config.rms_norm_eps,
|
|
||||||
qkv_bias=getattr(config, 'attention_bias', True),
|
|
||||||
head_dim=getattr(config, 'head_dim', None),
|
|
||||||
rope_theta=getattr(config, "rope_theta", 1000000),
|
|
||||||
rope_scaling=getattr(config, "rope_scaling", None),
|
|
||||||
)
|
|
||||||
# ...
|
|
||||||
```
|
|
||||||
|
|
||||||
### 方案 3: 版本检查与优雅降级
|
|
||||||
|
|
||||||
在 `nanovllm/__init__.py` 或启动时添加版本检查:
|
|
||||||
|
|
||||||
```python
|
|
||||||
import transformers
|
|
||||||
from packaging import version
|
|
||||||
|
|
||||||
TRANSFORMERS_VERSION = version.parse(transformers.__version__)
|
|
||||||
QWEN3_MIN_VERSION = version.parse("4.51.0")
|
|
||||||
|
|
||||||
QWEN3_AVAILABLE = TRANSFORMERS_VERSION >= QWEN3_MIN_VERSION
|
|
||||||
|
|
||||||
if not QWEN3_AVAILABLE:
|
|
||||||
import warnings
|
|
||||||
warnings.warn(
|
|
||||||
f"transformers {transformers.__version__} does not support Qwen3 models. "
|
|
||||||
f"Upgrade to >= 4.51.0 for Qwen3 support."
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
## 适配优先级
|
|
||||||
|
|
||||||
建议按以下优先级进行适配:
|
|
||||||
|
|
||||||
1. **P0 - models/__init__.py**: 添加 try-except 使 Llama 模型可独立使用
|
|
||||||
2. **P1 - qwen3.py**: 添加清晰的错误信息,说明版本要求
|
|
||||||
3. **P2 - 类型注解**: 可选地改为 `Any` 或使用 `TYPE_CHECKING`
|
|
||||||
4. **P3 - 文档**: 在 README 和 pyproject.toml 中说明版本依赖
|
|
||||||
|
|
||||||
## 测试验证
|
|
||||||
|
|
||||||
适配后应验证以下场景:
|
|
||||||
|
|
||||||
### 测试 1: 低版本环境(transformers 4.45.2)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 预期结果:Llama 模型可用,Qwen3 提示版本不足
|
|
||||||
docker run --rm \
|
|
||||||
-v /path/to/nano-vllm:/workspace/nano-vllm \
|
|
||||||
-e PYTHONPATH=/workspace/nano-vllm \
|
|
||||||
tzj/ruler:v0.3 \
|
|
||||||
python -c "
|
|
||||||
from nanovllm.models import get_model_class, MODEL_REGISTRY
|
|
||||||
print('Available models:', list(MODEL_REGISTRY.keys()))
|
|
||||||
# Expected: ['LlamaForCausalLM']
|
|
||||||
# Warning: Qwen3 models require transformers >= 4.51.0
|
|
||||||
"
|
|
||||||
```
|
|
||||||
|
|
||||||
### 测试 2: 高版本环境(transformers >= 4.51.0)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 预期结果:Llama 和 Qwen3 模型均可用
|
|
||||||
pip install 'transformers>=4.51.0'
|
|
||||||
python -c "
|
|
||||||
from nanovllm.models import get_model_class, MODEL_REGISTRY
|
|
||||||
print('Available models:', list(MODEL_REGISTRY.keys()))
|
|
||||||
# Expected: ['LlamaForCausalLM', 'Qwen3ForCausalLM', 'Qwen2ForCausalLM']
|
|
||||||
"
|
|
||||||
```
|
|
||||||
|
|
||||||
## 相关参考
|
|
||||||
|
|
||||||
- [Transformers Qwen3 文档](https://huggingface.co/docs/transformers/en/model_doc/qwen3)
|
|
||||||
- [Qwen3 GitHub](https://github.com/QwenLM/Qwen3)
|
|
||||||
- [Transformers 版本历史](https://github.com/huggingface/transformers/releases)
|
|
||||||
|
|
||||||
## 版本信息
|
|
||||||
|
|
||||||
| 日期 | 版本 | 变更 |
|
|
||||||
|------|------|------|
|
|
||||||
| 2025-01-11 | 1.0 | 初始文档,记录 transformers 4.45.2 兼容性问题 |
|
|
||||||
@@ -1,597 +0,0 @@
|
|||||||
# COMPASS XAttention Implementation Analysis
|
|
||||||
|
|
||||||
**Analysis Date**: 2026-01-14
|
|
||||||
**Researcher**: Claude Code Agent
|
|
||||||
**Source**: `/home/zijie/Code/COMPASS/compass/src/`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Executive Summary
|
|
||||||
|
|
||||||
COMPASS XAttention is a **block sparse attention** implementation that uses:
|
|
||||||
1. **Approximation phase** (`xattn_estimate`) to compute attention importance and select blocks
|
|
||||||
2. **Computation phase** (`Xattention_prefill`) to compute sparse attention using `block_sparse_attn_func`
|
|
||||||
3. **Triton kernels** for efficient block-wise GEMM and softmax operations
|
|
||||||
|
|
||||||
**Key Integration Constraint**: Requires `block_sparse_attn_func` from flash-attention library, which is a **C++ CUDA extension** that must be compiled separately.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 1. Function: `xattn_estimate()`
|
|
||||||
|
|
||||||
**Purpose**: Estimate attention importance and select which blocks to compute
|
|
||||||
|
|
||||||
### Input Parameters
|
|
||||||
|
|
||||||
| Parameter | Type | Default | Description |
|
|
||||||
|-----------|------|---------|-------------|
|
|
||||||
| `query_states` | Tensor | - | Shape: `(batch, num_heads, q_len, head_dim)` |
|
|
||||||
| `key_states` | Tensor | - | Shape: `(batch, num_kv_heads, k_len, head_dim)` |
|
|
||||||
| `block_size` | int | - | Size of attention blocks (typically 128) |
|
|
||||||
| `stride` | int | - | Downsampling stride for approximation |
|
|
||||||
| `norm` | float | 1 | Normalization factor for attention scaling |
|
|
||||||
| `softmax` | bool | True | Whether to apply softmax in estimation |
|
|
||||||
| `threshold` | float | 0.9 | Block selection threshold (0-1) |
|
|
||||||
| `chunk_size` | int | 16384 | Processing chunk size |
|
|
||||||
| `select_mode` | str | "inverse" | Pattern selection mode |
|
|
||||||
| `use_triton` | bool | True | Use Triton kernels (requires SM 80+) |
|
|
||||||
| `causal` | bool | True | Apply causal masking |
|
|
||||||
| `kdb` | int | 1 | Key downsampling factor |
|
|
||||||
| `keep_sink` | bool | False | Always attend to first token |
|
|
||||||
| `keep_recent` | bool | False | Always attend to recent tokens |
|
|
||||||
|
|
||||||
### Output
|
|
||||||
|
|
||||||
```python
|
|
||||||
returns: (attn_sums, simple_masks)
|
|
||||||
attn_sums: Tensor[float32]
|
|
||||||
Shape: (batch, num_heads, num_q_blocks, num_k_blocks_per_chunk)
|
|
||||||
Contains aggregated attention weights per block
|
|
||||||
|
|
||||||
simple_masks: Tensor[bool]
|
|
||||||
Shape: (batch, num_heads, num_q_blocks, num_k_blocks)
|
|
||||||
Boolean mask indicating which blocks to compute
|
|
||||||
```
|
|
||||||
|
|
||||||
### Algorithm
|
|
||||||
|
|
||||||
#### Step 1: Padding and Chunking
|
|
||||||
```python
|
|
||||||
# Pad sequences to chunk_size boundaries
|
|
||||||
k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
|
|
||||||
q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
|
|
||||||
|
|
||||||
# Compute number of blocks and chunks
|
|
||||||
k_chunk_num = (k_len + k_num_to_pad) // chunk_size
|
|
||||||
k_block_num = (k_len + k_num_to_pad) // block_size
|
|
||||||
q_chunk_num = (q_len + q_num_to_pad) // chunk_size
|
|
||||||
q_block_num = (q_len + q_num_to_pad) // block_size
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Step 2: Pattern Selection (stride-based downsampling)
|
|
||||||
|
|
||||||
**Purpose**: Reduce computation by `stride` factor using patterned selection
|
|
||||||
|
|
||||||
**Modes**:
|
|
||||||
1. **`"inverse"`** (default): Inverse stride pattern
|
|
||||||
```python
|
|
||||||
# Key: regular stride [0, stride, 2*stride, ...]
|
|
||||||
# Query: reverse stride [(stride-1), (stride-1-stride), ...]
|
|
||||||
reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)])
|
|
||||||
reshaped_query = torch.cat([query_states[:, :, (stride-1-q)::stride*kdb, :] for q in range(stride)])
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **`"slash"`**: Slash pattern (diagonal)
|
|
||||||
```python
|
|
||||||
# Both use regular stride
|
|
||||||
reshaped_key = torch.cat([key_states[:, :, k::stride, :] for k in range(stride)])
|
|
||||||
reshaped_query = torch.cat([query_states[:, :, q::stride, :] for q in range(stride)])
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **`"random"`**: Random permutation
|
|
||||||
4. **`"double"`, `"triple"`**: Data augmentation modes
|
|
||||||
|
|
||||||
#### Step 3: Chunk-wise Attention Estimation
|
|
||||||
|
|
||||||
For each query chunk:
|
|
||||||
|
|
||||||
**If `use_triton=True`** (fast path):
|
|
||||||
```python
|
|
||||||
# Triton kernel 1: Compute attention scores with fused reshape
|
|
||||||
attn_weights_slice = flat_group_gemm_fuse_reshape(
|
|
||||||
query_chunk, key_states, stride,
|
|
||||||
chunk_start, chunk_end, is_causal=causal
|
|
||||||
)
|
|
||||||
|
|
||||||
# Triton kernel 2: Softmax + block aggregation
|
|
||||||
attn_sum = softmax_fuse_block_sum(
|
|
||||||
attn_weights_slice, reshaped_block_size, segment_size,
|
|
||||||
chunk_start, chunk_end, real_q_len, scale, is_causal
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**If `use_triton=False`** (PyTorch fallback):
|
|
||||||
```python
|
|
||||||
# Standard matrix multiplication
|
|
||||||
attn_weights_slice = torch.matmul(chunked_query, reshaped_key.transpose(2, 3))
|
|
||||||
|
|
||||||
# Scale and apply causal mask
|
|
||||||
attn_weights_slice = attn_weights_slice / sqrt(head_dim) / stride / norm
|
|
||||||
attn_weights_slice = attn_weights_slice + causal_mask
|
|
||||||
|
|
||||||
# Softmax
|
|
||||||
attn_weights_slice = F.softmax(attn_weights_slice, dim=-1)
|
|
||||||
|
|
||||||
# Aggregate to block level
|
|
||||||
attn_sum = attn_weights_slice.view(
|
|
||||||
batch, heads, num_blocks_per_chunk, block_size//kdb, -1, block_size
|
|
||||||
).sum(dim=-1).sum(dim=-2)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Step 4: Block Selection
|
|
||||||
|
|
||||||
```python
|
|
||||||
# Select blocks based on threshold
|
|
||||||
simple_mask = find_blocks_chunked(
|
|
||||||
attn_sum,
|
|
||||||
current_index, # Starting block index
|
|
||||||
threshold, # 0.9 = select blocks covering 90% of attention mass
|
|
||||||
None, # or num_to_choose for top-k selection
|
|
||||||
decoding=False,
|
|
||||||
mode="prefill",
|
|
||||||
causal=True
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Selection Algorithm** (`find_blocks_chunked`):
|
|
||||||
1. Sort blocks by attention weight (descending)
|
|
||||||
2. Compute cumulative sum
|
|
||||||
3. Select blocks until `cumulative_sum >= total_sum * threshold`
|
|
||||||
4. Enforce causal constraints (no future blocks)
|
|
||||||
5. Always include sink token (first block) if `keep_sink=True`
|
|
||||||
6. Always include diagonal blocks if `keep_recent=True`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 2. Function: `Xattention_prefill()`
|
|
||||||
|
|
||||||
**Purpose**: Compute sparse attention using estimated block mask
|
|
||||||
|
|
||||||
### Input Parameters
|
|
||||||
|
|
||||||
| Parameter | Type | Default | Description |
|
|
||||||
|-----------|------|---------|-------------|
|
|
||||||
| `query_states` | Tensor | - | `(batch, num_heads, q_len, head_dim)` |
|
|
||||||
| `key_states` | Tensor | - | `(batch, num_heads, k_len, head_dim)` |
|
|
||||||
| `value_states` | Tensor | - | `(batch, num_heads, k_len, head_dim)` |
|
|
||||||
| `stride` | int | - | Downsampling stride for estimation |
|
|
||||||
| `norm` | float | 1 | Normalization factor |
|
|
||||||
| `threshold` | float | 0.8 | Block selection threshold |
|
|
||||||
| `block_size` | int | 128 | **MUST be 128** (hardcoded requirement) |
|
|
||||||
| `use_triton` | bool | True | Use Triton kernels in estimation |
|
|
||||||
| `causal` | bool | True | Apply causal masking |
|
|
||||||
| `kdb` | int | 1 | Key downsampling factor |
|
|
||||||
| `chunk_size` | int | None | Auto-computed if None |
|
|
||||||
| `keep_sink` | bool | False | Always attend to first token |
|
|
||||||
| `keep_recent` | bool | False | Always attend to recent tokens |
|
|
||||||
|
|
||||||
### Output
|
|
||||||
|
|
||||||
```python
|
|
||||||
returns: attn_output
|
|
||||||
attn_output: Tensor
|
|
||||||
Shape: (batch, num_heads, q_len, head_dim)
|
|
||||||
Sparse attention output
|
|
||||||
```
|
|
||||||
|
|
||||||
### Algorithm Flow
|
|
||||||
|
|
||||||
#### Step 1: Auto-compute chunk_size
|
|
||||||
```python
|
|
||||||
if chunk_size is None:
|
|
||||||
chunk_size = int(max(
|
|
||||||
min(
|
|
||||||
max(2048, 1 << (k_len - 1).bit_length()), # Round to power of 2
|
|
||||||
128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()), # Memory constraint
|
|
||||||
),
|
|
||||||
2048, # Minimum
|
|
||||||
))
|
|
||||||
```
|
|
||||||
|
|
||||||
**Example**:
|
|
||||||
- `k_len=8192` → `chunk_size=8192`
|
|
||||||
- `k_len=32768` → `chunk_size=16384`
|
|
||||||
- `k_len=65536` → `chunk_size=16384`
|
|
||||||
|
|
||||||
#### Step 2: Estimate attention and select blocks
|
|
||||||
```python
|
|
||||||
attn_sums, approx_simple_mask = xattn_estimate(
|
|
||||||
query_states, key_states,
|
|
||||||
block_size=block_size, stride=stride, norm=norm,
|
|
||||||
threshold=threshold, select_mode="inverse",
|
|
||||||
use_triton=use_triton, causal=causal,
|
|
||||||
chunk_size=chunk_size, kdb=kdb,
|
|
||||||
keep_sink=keep_sink, keep_recent=keep_recent
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Step 3: Prepare inputs for block_sparse_attn_func
|
|
||||||
```python
|
|
||||||
# Hard constraints
|
|
||||||
assert block_size == 128
|
|
||||||
assert batch_size == 1
|
|
||||||
|
|
||||||
# Reshape to (seq_len, num_heads, head_dim)
|
|
||||||
query_states = query_states.transpose(1, 2).view(q_len, num_heads, head_dim)
|
|
||||||
key_states = key_states.transpose(1, 2).view(k_len, num_heads, head_dim)
|
|
||||||
value_states = value_states.transpose(1, 2).view(k_len, num_heads, head_dim)
|
|
||||||
|
|
||||||
# Cumulative sequence lengths
|
|
||||||
q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32, device=device)
|
|
||||||
k_cu_seq_lens = torch.tensor([0, k_len], dtype=torch.int32, device=device)
|
|
||||||
|
|
||||||
# Head mask type (all heads use mask)
|
|
||||||
head_mask_type = torch.tensor([1 for _ in range(num_heads)], dtype=torch.int32)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Step 4: Call block_sparse_attn_func
|
|
||||||
```python
|
|
||||||
attn_output = block_sparse_attn_func(
|
|
||||||
query_states, # (q_len, num_heads, head_dim)
|
|
||||||
key_states, # (k_len, num_heads, head_dim)
|
|
||||||
value_states, # (k_len, num_heads, head_dim)
|
|
||||||
q_cu_seq_lens, # [0, q_len]
|
|
||||||
k_cu_seq_lens, # [0, k_len]
|
|
||||||
head_mask_type, # [1, 1, ..., 1]
|
|
||||||
None, # No custom layout
|
|
||||||
approx_simple_mask[:, :, :q_block_num, :k_block_num].contiguous(), # Block mask
|
|
||||||
q_len,
|
|
||||||
k_len,
|
|
||||||
p_dropout=0.0,
|
|
||||||
deterministic=True,
|
|
||||||
is_causal=causal
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Step 5: Reshape output
|
|
||||||
```python
|
|
||||||
attn_output = attn_output.view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
|
|
||||||
# Output shape: (batch, num_heads, q_len, head_dim)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 3. Triton Kernel Dependencies
|
|
||||||
|
|
||||||
### Kernel 1: `flat_group_gemm_fuse_reshape_kernel`
|
|
||||||
|
|
||||||
**Purpose**: Compute QK^T with stride-based reshaping
|
|
||||||
|
|
||||||
**Key Features**:
|
|
||||||
- Loads `stride` keys and queries at once
|
|
||||||
- Fused strided access pattern
|
|
||||||
- Causal masking support
|
|
||||||
- Block size auto-selection based on GPU memory
|
|
||||||
|
|
||||||
**Block Size Selection**:
|
|
||||||
```python
|
|
||||||
# RTX 3090 (<30GB): BLOCK_M=64, BLOCK_N=64
|
|
||||||
# A100/H100 (>=30GB): BLOCK_M=128, BLOCK_N=128
|
|
||||||
```
|
|
||||||
|
|
||||||
**Signature**:
|
|
||||||
```python
|
|
||||||
flat_group_gemm_fuse_reshape(
|
|
||||||
query_states, # (batch, heads, q_len, head_dim)
|
|
||||||
key_states, # (batch, heads, k_len, head_dim)
|
|
||||||
stride, # Downsampling factor
|
|
||||||
chunk_start, # Start position in keys
|
|
||||||
chunk_end, # End position in keys
|
|
||||||
is_causal=True
|
|
||||||
)
|
|
||||||
# Returns: (batch, heads, q_len//stride, k_len//stride)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Kernel 2: `softmax_fuse_block_sum_kernel_causal` / `_non_causal`
|
|
||||||
|
|
||||||
**Purpose**: Online softmax with block aggregation
|
|
||||||
|
|
||||||
**Algorithm**:
|
|
||||||
1. **Forward pass** (compute m_i, l_i):
|
|
||||||
```
|
|
||||||
m_i = max(m_i, m_local)
|
|
||||||
alpha = exp(m_i - m_new)
|
|
||||||
l_i = l_i * alpha + sum(exp(X - m_new))
|
|
||||||
```
|
|
||||||
2. **Backward pass** (compute softmax with scaling):
|
|
||||||
```
|
|
||||||
softmax = exp(X - m_i) / l_i
|
|
||||||
aggregate to blocks: sum(softmax) over block_size
|
|
||||||
```
|
|
||||||
|
|
||||||
**Key Features**:
|
|
||||||
- Single-pass softmax (no materializing full attention matrix)
|
|
||||||
- Causal masking integrated
|
|
||||||
- Outputs block-level sums directly
|
|
||||||
|
|
||||||
**Signature**:
|
|
||||||
```python
|
|
||||||
softmax_fuse_block_sum(
|
|
||||||
attn_weights_slice, # (batch, heads, q_len, k_len)
|
|
||||||
reshaped_block_size, # Block size (128//stride)
|
|
||||||
segment_size, # Processing segment (min(4096, block_size))
|
|
||||||
chunk_start, # Start position
|
|
||||||
chunk_end, # End position
|
|
||||||
real_q_len, # Actual query length (before padding)
|
|
||||||
scale, # 1.4426950408889634 / sqrt(head_dim) / stride / norm
|
|
||||||
is_causal=True
|
|
||||||
)
|
|
||||||
# Returns: (batch, heads, q_len//block_size, k_len//block_size)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 4. Key Parameters and Their Meanings
|
|
||||||
|
|
||||||
### Critical Parameters
|
|
||||||
|
|
||||||
| Parameter | Meaning | Typical Value | Impact |
|
|
||||||
|-----------|---------|---------------|--------|
|
|
||||||
| `block_size` | Block granularity | 128 | **Fixed at 128**, affects mask granularity |
|
|
||||||
| `stride` | Downsampling factor | 4-16 | Higher = faster but less accurate |
|
|
||||||
| `threshold` | Sparsity level | 0.8-0.9 | Higher = denser mask, more computation |
|
|
||||||
| `chunk_size` | Processing chunk | 16384 | Affects memory and efficiency |
|
|
||||||
| `kdb` | Key downsampling boost | 1 | Experimental, use 1 |
|
|
||||||
| `norm` | Scaling factor | 1.0 | Attention temperature control |
|
|
||||||
|
|
||||||
### Trade-offs
|
|
||||||
|
|
||||||
**Stride (`stride`)**:
|
|
||||||
- `stride=1`: No approximation, same as dense attention
|
|
||||||
- `stride=4`: 4x faster estimation, good accuracy
|
|
||||||
- `stride=8`: 8x faster, moderate accuracy loss
|
|
||||||
- `stride=16`: 16x faster, significant accuracy loss
|
|
||||||
|
|
||||||
**Threshold (`threshold`)**:
|
|
||||||
- `threshold=0.8`: Select blocks covering 80% of attention mass (~20% sparsity)
|
|
||||||
- `threshold=0.9`: Select blocks covering 90% of attention mass (~10% sparsity)
|
|
||||||
- `threshold=0.95`: Very dense, only prunes ~5% of blocks
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 5. Dependencies
|
|
||||||
|
|
||||||
### Required Libraries
|
|
||||||
|
|
||||||
1. **`block_sparse_attn`** (CRITICAL)
|
|
||||||
- Source: `/home/zijie/Code/COMPASS/3rdparty/flash-attention/`
|
|
||||||
- Function: `block_sparse_attn_func`
|
|
||||||
- Type: **C++ CUDA extension**
|
|
||||||
- Build: Requires compilation with `torch.utils.cpp_extension`
|
|
||||||
|
|
||||||
2. **Triton** (optional but recommended)
|
|
||||||
- Required for: `use_triton=True`
|
|
||||||
- GPU requirement: SM 80+ (A100, RTX 3090, H100, etc.)
|
|
||||||
- Check: `torch.cuda.get_device_properties().major >= 8`
|
|
||||||
|
|
||||||
3. **PyTorch**
|
|
||||||
- Version: Compatible with flash-attention
|
|
||||||
- Features: F.pad, matmul, softmax, view, transpose
|
|
||||||
|
|
||||||
### Dependency Tree
|
|
||||||
|
|
||||||
```
|
|
||||||
Xattention_prefill
|
|
||||||
├── xattn_estimate
|
|
||||||
│ ├── flat_group_gemm_fuse_reshape (Triton)
|
|
||||||
│ ├── softmax_fuse_block_sum (Triton)
|
|
||||||
│ └── find_blocks_chunked (PyTorch)
|
|
||||||
└── block_sparse_attn_func (C++ CUDA)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 6. Integration Issues for nano-vllm
|
|
||||||
|
|
||||||
### Critical Issue 1: `block_sparse_attn_func` Dependency
|
|
||||||
|
|
||||||
**Problem**: `block_sparse_attn_func` is a **C++ CUDA extension** that must be compiled from flash-attention source.
|
|
||||||
|
|
||||||
**Options**:
|
|
||||||
1. **Compile flash-attention with block sparse support**
|
|
||||||
```bash
|
|
||||||
cd /home/zijie/Code/COMPASS/3rdparty/flash-attention
|
|
||||||
python setup.py install
|
|
||||||
```
|
|
||||||
- Risk: May conflict with existing flash-attention installation
|
|
||||||
- Complexity: High (C++ compilation)
|
|
||||||
|
|
||||||
2. **Replace with FlashInfer block sparse**
|
|
||||||
- FlashInfer is already a dependency
|
|
||||||
- Has similar block sparse attention
|
|
||||||
- Need to adapt interface
|
|
||||||
|
|
||||||
3. **Custom CUDA kernel**
|
|
||||||
- Implement simplified block sparse attention
|
|
||||||
- High development cost
|
|
||||||
- Maintenance burden
|
|
||||||
|
|
||||||
### Critical Issue 2: Hard-coded Constraints
|
|
||||||
|
|
||||||
```python
|
|
||||||
assert block_size == 128 # Line 358
|
|
||||||
assert batch_size == 1 # Line 359
|
|
||||||
```
|
|
||||||
|
|
||||||
**Impact**:
|
|
||||||
- Cannot process multiple sequences in one batch
|
|
||||||
- Fixed block size limits flexibility
|
|
||||||
- Must work around these constraints
|
|
||||||
|
|
||||||
### Critical Issue 3: Triton GPU Requirement
|
|
||||||
|
|
||||||
```python
|
|
||||||
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
|
||||||
if props.major < 8:
|
|
||||||
use_triton = False
|
|
||||||
```
|
|
||||||
|
|
||||||
**Impact**:
|
|
||||||
- Triton kernels only work on SM 80+ (A100, RTX 3090, H100)
|
|
||||||
- Older GPUs (V100, T4, RTX 2080) fall back to slow PyTorch implementation
|
|
||||||
- RTX 3090 works but uses smaller block sizes (64 vs 128)
|
|
||||||
|
|
||||||
### Issue 4: Memory Layout
|
|
||||||
|
|
||||||
**XAttention expects**:
|
|
||||||
```python
|
|
||||||
query_states: (batch, num_heads, q_len, head_dim)
|
|
||||||
```
|
|
||||||
|
|
||||||
**nano-vllm uses**:
|
|
||||||
```python
|
|
||||||
query_states: (num_heads, total_tokens, head_dim) # Flattened batch
|
|
||||||
```
|
|
||||||
|
|
||||||
**Required**: Transpose and reshape before/after calling XAttention
|
|
||||||
|
|
||||||
### Issue 5: Chunking Incompatibility
|
|
||||||
|
|
||||||
**XAttention**: Processes in fixed-size chunks (e.g., 16384 tokens)
|
|
||||||
- Requires padding to chunk boundaries
|
|
||||||
- Adds overhead for short sequences
|
|
||||||
|
|
||||||
**nano-vllm**: Processes variable-length requests
|
|
||||||
- No padding requirement
|
|
||||||
- Dynamic batch sizing
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 7. Integration Strategy
|
|
||||||
|
|
||||||
### Recommended Approach: **Wrapper with FlashInfer**
|
|
||||||
|
|
||||||
1. **Keep `xattn_estimate`** (pure PyTorch + Triton)
|
|
||||||
- No external dependencies
|
|
||||||
- Computes block mask
|
|
||||||
|
|
||||||
2. **Replace `block_sparse_attn_func` with FlashInfer**
|
|
||||||
- FlashInfer: `flashinfer.single_prefill_with_kv_cache`
|
|
||||||
- Similar API, already compiled
|
|
||||||
- Supports block sparse
|
|
||||||
|
|
||||||
3. **Adapt mask format**
|
|
||||||
- XAttention: `(batch, heads, q_blocks, k_blocks)` boolean mask
|
|
||||||
- FlashInfer: `(num_qo, num_kv)` boolean mask or custom format
|
|
||||||
|
|
||||||
4. **Handle constraints**
|
|
||||||
- Enforce `batch_size=1` by processing one request at a time
|
|
||||||
- Keep `block_size=128` as requirement
|
|
||||||
|
|
||||||
### Alternative: **Pure PyTorch Implementation**
|
|
||||||
|
|
||||||
1. Extract estimation algorithm
|
|
||||||
2. Implement sparse attention using PyTorch operations
|
|
||||||
3. Use FlashInfer for final computation
|
|
||||||
4. No Triton dependency
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 8. Code Example: Adaptation
|
|
||||||
|
|
||||||
```python
|
|
||||||
def xattention_prefill_adapted(
|
|
||||||
query_states, # (num_heads, q_len, head_dim)
|
|
||||||
key_states, # (num_heads, k_len, head_dim)
|
|
||||||
value_states, # (num_heads, k_len, head_dim)
|
|
||||||
stride=4,
|
|
||||||
threshold=0.9,
|
|
||||||
block_size=128,
|
|
||||||
causal=True,
|
|
||||||
):
|
|
||||||
# Step 1: Add batch dimension
|
|
||||||
q = query_states.unsqueeze(0) # (1, heads, q_len, dim)
|
|
||||||
k = key_states.unsqueeze(0)
|
|
||||||
v = value_states.unsqueeze(0)
|
|
||||||
|
|
||||||
# Step 2: Estimate mask (no external dependency)
|
|
||||||
_, block_mask = xattn_estimate(
|
|
||||||
q, k,
|
|
||||||
block_size=block_size,
|
|
||||||
stride=stride,
|
|
||||||
threshold=threshold,
|
|
||||||
use_triton=True,
|
|
||||||
causal=causal,
|
|
||||||
)
|
|
||||||
# block_mask: (1, heads, q_blocks, k_blocks)
|
|
||||||
|
|
||||||
# Step 3: Convert block mask to token mask
|
|
||||||
q_blocks, k_blocks = block_mask.shape[-2:]
|
|
||||||
token_mask = block_mask.repeat_interleave(block_size, dim=-2)
|
|
||||||
token_mask = token_mask.repeat_interleave(block_size, dim=-1)
|
|
||||||
token_mask = token_mask[:, :, :q.size(2), :k.size(2)] # Trim padding
|
|
||||||
|
|
||||||
# Step 4: Use FlashInfer with mask
|
|
||||||
from flashinfer import single_prefill_with_kv_cache
|
|
||||||
output = single_prefill_with_kv_cache(
|
|
||||||
q.squeeze(0),
|
|
||||||
k.squeeze(0),
|
|
||||||
v.squeeze(0),
|
|
||||||
custom_mask=token_mask.squeeze(0),
|
|
||||||
)
|
|
||||||
|
|
||||||
return output # (num_heads, q_len, head_dim)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 9. Summary of Findings
|
|
||||||
|
|
||||||
### Advantages
|
|
||||||
|
|
||||||
1. **Accurate approximation**: Pattern-based stride selection preserves attention patterns
|
|
||||||
2. **Flexible sparsity**: Threshold-based control over computation
|
|
||||||
3. **GPU optimization**: Triton kernels for estimation phase
|
|
||||||
4. **Proven in practice**: Used in COMPASS system
|
|
||||||
|
|
||||||
### Challenges
|
|
||||||
|
|
||||||
1. **Hard dependency**: `block_sparse_attn_func` requires C++ compilation
|
|
||||||
2. **Rigid constraints**: `block_size=128`, `batch_size=1`
|
|
||||||
3. **GPU-specific**: Triton only on SM 80+
|
|
||||||
4. **Memory layout mismatch**: Requires reshape/transpose
|
|
||||||
5. **Chunking overhead**: Padding to chunk boundaries
|
|
||||||
|
|
||||||
### Integration Complexity
|
|
||||||
|
|
||||||
| Component | Complexity | Risk |
|
|
||||||
|-----------|------------|------|
|
|
||||||
| `xattn_estimate` | Medium | Low (PyTorch + Triton) |
|
|
||||||
| `block_sparse_attn_func` | High | **Critical** (C++ dependency) |
|
|
||||||
| Interface adaptation | Low | Low (reshape) |
|
|
||||||
| Constraint handling | Medium | Medium (workarounds) |
|
|
||||||
|
|
||||||
**Overall Integration Risk**: **HIGH** (due to C++ dependency)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 10. Next Steps
|
|
||||||
|
|
||||||
1. **Evaluate FlashInfer compatibility**
|
|
||||||
- Can FlashInfer replace `block_sparse_attn_func`?
|
|
||||||
- What mask format does it expect?
|
|
||||||
|
|
||||||
2. **Prototype estimation phase**
|
|
||||||
- Extract `xattn_estimate` function
|
|
||||||
- Test with nano-vllm inputs
|
|
||||||
- Validate mask quality
|
|
||||||
|
|
||||||
3. **Benchmark Triton kernels**
|
|
||||||
- Compare Triton vs PyTorch estimation
|
|
||||||
- Measure speedup on RTX 3090
|
|
||||||
- Profile memory usage
|
|
||||||
|
|
||||||
4. **Design interface**
|
|
||||||
- Define nano-vllm sparse attention API
|
|
||||||
- Specify mask format
|
|
||||||
- Plan integration points
|
|
||||||
@@ -1,961 +0,0 @@
|
|||||||
# XAttention 集成指南
|
|
||||||
|
|
||||||
本文档详细记录了将 COMPASS 的 XAttention 算法集成到 nano-vllm 的完整过程,包括算法原理、源码分析、设计决策、实现细节和测试验证。
|
|
||||||
|
|
||||||
## 目录
|
|
||||||
|
|
||||||
1. [背景](#1-背景)
|
|
||||||
2. [XAttention 算法原理](#2-xattention-算法原理)
|
|
||||||
3. [COMPASS 源码分析](#3-compass-源码分析)
|
|
||||||
4. [集成设计决策](#4-集成设计决策)
|
|
||||||
5. [实现细节](#5-实现细节)
|
|
||||||
6. [问题与解决方案](#6-问题与解决方案)
|
|
||||||
7. [测试验证](#7-测试验证)
|
|
||||||
8. [使用指南](#8-使用指南)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 1. 背景
|
|
||||||
|
|
||||||
### 1.1 为什么需要 XAttention
|
|
||||||
|
|
||||||
- **长上下文推理需求**:随着 LLM 上下文长度扩展到 32k、64k 甚至更长,传统注意力机制的计算复杂度 O(n²) 成为瓶颈
|
|
||||||
- **COMPASS 算法**:通过 chunked estimation 和 block sparse attention 实现 O(n) 复杂度
|
|
||||||
- **nano-vllm 集成目标**:在 CPU offload 模式下支持高效的长上下文推理
|
|
||||||
|
|
||||||
### 1.2 集成范围
|
|
||||||
|
|
||||||
**仅关注 offload 执行路径**:
|
|
||||||
- `run_layerwise_offload_prefill()` - layer-wise chunked prefill
|
|
||||||
- CPU offload 模式下的 KV cache 管理
|
|
||||||
- 与 `SparsePolicy` 框架的集成
|
|
||||||
|
|
||||||
### 1.3 参考
|
|
||||||
|
|
||||||
- COMPASS 源码:`/home/zijie/Code/COMPASS/compass/src/`
|
|
||||||
- 关键文件:`Xattention.py`, `kernels.py`, `utils.py`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 2. XAttention 算法原理
|
|
||||||
|
|
||||||
### 2.1 两阶段设计
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────────────────────────────────────────────────────┐
|
|
||||||
│ XAttention 流程 │
|
|
||||||
├─────────────────────────────────────────────────────────────┤
|
|
||||||
│ │
|
|
||||||
│ Phase 1: Chunked Estimation │
|
|
||||||
│ ┌─────────────┐ ┌──────────────┐ ┌─────────────┐ │
|
|
||||||
│ │ Query Chunk │ -> │ Triton GEMM │ -> │ Attn Scores │ │
|
|
||||||
│ │ (stride=8) │ │ (fused) │ │ (per block) │ │
|
|
||||||
│ └─────────────┘ └──────────────┘ └─────────────┘ │
|
|
||||||
│ ↓ │
|
|
||||||
│ ┌─────────────┐ │
|
|
||||||
│ │ Block Mask │ │
|
|
||||||
│ │ (threshold) │ │
|
|
||||||
│ └─────────────┘ │
|
|
||||||
│ │
|
|
||||||
│ Phase 2: Block Sparse Attention │
|
|
||||||
│ ┌─────────────┐ ┌──────────────┐ ┌─────────────┐ │
|
|
||||||
│ │ Selected Q │ -> │ Block Sparse │ -> │ Output │ │
|
|
||||||
│ │ + Selected K│ │ Attention │ │ │ │
|
|
||||||
│ └─────────────┘ └──────────────┘ └─────────────┘ │
|
|
||||||
│ │
|
|
||||||
└─────────────────────────────────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2.2 关键参数
|
|
||||||
|
|
||||||
| 参数 | 默认值 | 说明 |
|
|
||||||
|------|--------|------|
|
|
||||||
| `stride` | 8 | Q/K 重组步长 |
|
|
||||||
| `block_size` | 128 | Block 大小(tokens) |
|
|
||||||
| `threshold` | 0.9 | Block 选择阈值 (0-1) |
|
|
||||||
| `chunk_size` | 16384 | Estimation chunk 大小 |
|
|
||||||
|
|
||||||
### 2.3 计算流程
|
|
||||||
|
|
||||||
1. **Chunked Estimation**:
|
|
||||||
- 将 Q 分成固定大小的 chunks
|
|
||||||
- 使用 Triton kernels 计算 QK^T(fused GEMM + reshape)
|
|
||||||
- 分块 softmax 并聚合到 block 级别
|
|
||||||
- 根据阈值选择重要 blocks
|
|
||||||
|
|
||||||
2. **Block Sparse Attention**:
|
|
||||||
- 只计算选中 blocks 的注意力
|
|
||||||
- 使用 block sparse kernels 优化
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 3. COMPASS 源码分析
|
|
||||||
|
|
||||||
### 3.1 核心文件结构
|
|
||||||
|
|
||||||
```
|
|
||||||
COMPASS/compass/src/
|
|
||||||
├── Xattention.py # XAttention 主算法
|
|
||||||
├── kernels.py # Triton kernels
|
|
||||||
├── utils.py # 辅助函数
|
|
||||||
└── block_sparse.py # Block sparse attention
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3.2 Xattention.py 分析
|
|
||||||
|
|
||||||
**核心函数**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def xattn_estimate(
|
|
||||||
query_states, key_states, value_states,
|
|
||||||
stride, block_size, threshold, ...
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Phase 1: 估算稀疏注意力模式
|
|
||||||
|
|
||||||
返回:
|
|
||||||
attn_sums: [batch, heads, q_blocks, k_blocks] 重要性分数
|
|
||||||
simple_masks: [batch, heads, q_blocks, k_blocks] 布尔掩码
|
|
||||||
"""
|
|
||||||
# 1. Pad inputs to chunk_size multiples
|
|
||||||
# 2. Reshape with stride
|
|
||||||
# 3. Compute QK^T in chunks (Triton)
|
|
||||||
# 4. Block-wise softmax + aggregation
|
|
||||||
# 5. Threshold-based selection
|
|
||||||
return attn_sums, simple_masks
|
|
||||||
|
|
||||||
|
|
||||||
def Xattention_prefill(
|
|
||||||
query_states, key_states, value_states,
|
|
||||||
stride, threshold, ...
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
完整 XAttention prefill
|
|
||||||
|
|
||||||
流程:
|
|
||||||
1. xattn_estimate() - 获取 block mask
|
|
||||||
2. block_sparse_attn_func() - 稀疏注意力计算
|
|
||||||
"""
|
|
||||||
attn_sums, simple_masks = xattn_estimate(...)
|
|
||||||
attn_output = block_sparse_attn_func(
|
|
||||||
query_states, key_states, value_states,
|
|
||||||
simple_masks, block_size
|
|
||||||
)
|
|
||||||
return attn_output
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3.3 kernels.py 分析
|
|
||||||
|
|
||||||
**Triton Kernels**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
@triton.jit
|
|
||||||
def flat_group_gemm_fuse_reshape_kernel(Q, K, Out, ...):
|
|
||||||
"""
|
|
||||||
Stride-based GEMM with reshape fusion
|
|
||||||
|
|
||||||
关键优化:
|
|
||||||
- Stride 访问模式:每隔 stride 个 token 访问一次
|
|
||||||
- Fused reshape:避免单独的 reshape 操作
|
|
||||||
- Block-level 并行:M×N block tiling
|
|
||||||
"""
|
|
||||||
# Load Q and K with stride
|
|
||||||
for iter in range(STRIDE):
|
|
||||||
q = tl.load(Q_ptrs - iter * stride_qn)
|
|
||||||
k = tl.load(K_ptrs + iter * stride_kn)
|
|
||||||
o += tl.dot(q, k)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def softmax_fuse_block_sum_kernel_causal(In, Out, ...):
|
|
||||||
"""
|
|
||||||
Block-wise softmax with sum aggregation
|
|
||||||
|
|
||||||
关键优化:
|
|
||||||
- Online softmax:避免存储完整注意力矩阵
|
|
||||||
- Block sum:聚合到 block 级别
|
|
||||||
- Causal mask:支持因果注意力
|
|
||||||
"""
|
|
||||||
# Online softmax (m_i, l_i)
|
|
||||||
m_new = tl.maximum(m_i, m_local)
|
|
||||||
alpha = tl.math.exp2(m_i - m_new)
|
|
||||||
l_i = l_i * alpha + l_local
|
|
||||||
m_i = m_new
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3.4 utils.py 分析
|
|
||||||
|
|
||||||
**关键函数**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def find_blocks_chunked(
|
|
||||||
input_tensor, # [batch, heads, chunk_q, block_k]
|
|
||||||
current_index,
|
|
||||||
threshold, # 0-1
|
|
||||||
num_to_choose,
|
|
||||||
decoding,
|
|
||||||
mode,
|
|
||||||
causal
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
基于阈值选择重要 blocks
|
|
||||||
|
|
||||||
返回:
|
|
||||||
boolean mask: [batch, heads, chunk_q, block_k]
|
|
||||||
"""
|
|
||||||
# 1. 计算阈值分数
|
|
||||||
score_threshold = input_tensor.max() * threshold
|
|
||||||
|
|
||||||
# 2. 生成布尔掩码
|
|
||||||
masks = (input_tensor >= score_threshold)
|
|
||||||
|
|
||||||
# 3. 应用因果约束
|
|
||||||
if causal:
|
|
||||||
# 只保留下三角区域
|
|
||||||
...
|
|
||||||
|
|
||||||
return masks
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 4. 集成设计决策
|
|
||||||
|
|
||||||
### 4.1 稀疏策略框架
|
|
||||||
|
|
||||||
nano-vllm 使用 `SparsePolicy` 抽象接口:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class SparsePolicy(ABC):
|
|
||||||
"""稀疏注意力策略基类"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_prefill(self) -> bool:
|
|
||||||
"""是否支持 prefill 阶段"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def supports_decode(self) -> bool:
|
|
||||||
"""是否支持 decode 阶段"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def requires_block_selection(self) -> bool:
|
|
||||||
"""是否需要 block selection(用于 KV cache 加载)"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def select_blocks(self, available_blocks, ctx) -> List[int]:
|
|
||||||
"""选择要加载的 KV blocks"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def sparse_prefill_attention(self, q, k, v, layer_id) -> torch.Tensor:
|
|
||||||
"""计算稀疏 prefill 注意力"""
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4.2 XAttention 设计决策
|
|
||||||
|
|
||||||
#### 决策 1:Prefill-Only 策略
|
|
||||||
|
|
||||||
```python
|
|
||||||
class XAttentionPolicy(SparsePolicy):
|
|
||||||
supports_prefill = True
|
|
||||||
supports_decode = False # XAttention 仅用于 prefill
|
|
||||||
requires_block_selection = False # 不影响 KV cache 加载
|
|
||||||
```
|
|
||||||
|
|
||||||
**原因**:
|
|
||||||
- XAttention 是 prefill 阶段的优化算法
|
|
||||||
- Decode 阶段使用其他策略(如 QUEST)
|
|
||||||
- Block selection 不在 XAttention 范围内
|
|
||||||
|
|
||||||
#### 决策 2:CPU Offload 模式简化
|
|
||||||
|
|
||||||
```python
|
|
||||||
def sparse_prefill_attention(self, q, k, v, layer_id):
|
|
||||||
# 使用 FlashAttention 直接计算
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
|
||||||
|
|
||||||
attn_output = flash_attn_varlen_func(
|
|
||||||
q, k, v,
|
|
||||||
cu_seqlens_q=cu_seqlens,
|
|
||||||
cu_seqlens_k=cu_seqlens,
|
|
||||||
max_seqlen_q=seq_len,
|
|
||||||
max_seqlen_k=seq_len,
|
|
||||||
softmax_scale=1.0 / math.sqrt(head_dim),
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
return attn_output
|
|
||||||
```
|
|
||||||
|
|
||||||
**关键原因**:
|
|
||||||
|
|
||||||
1. **Chunked Prefill 架构限制**:
|
|
||||||
```
|
|
||||||
Offload 模式: run_layerwise_offload_prefill()
|
|
||||||
└─ 每次只处理一个 chunk (2048 tokens)
|
|
||||||
└─ 完整的 key_states 在 CPU,不在当前调用栈
|
|
||||||
└─ 无法进行完整的 chunked estimation
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Estimation 需要完整上下文**:
|
|
||||||
- XAttention 的 estimation 需要访问完整 key_states
|
|
||||||
- Offload 模式下 keys 分层存储在 CPU
|
|
||||||
- 传递所有 keys 会破坏 offload 的内存优势
|
|
||||||
|
|
||||||
3. **FlashAttention 原生支持 GQA**:
|
|
||||||
- GQA (Grouped Query Attention): num_kv_heads < num_heads
|
|
||||||
- FlashAttention 自动处理 head 展开
|
|
||||||
- 避免手动实现的复杂性
|
|
||||||
|
|
||||||
#### 决策 3:保留 Triton Kernels
|
|
||||||
|
|
||||||
虽然 CPU offload 模式使用 FlashAttention,但仍保留 Triton kernels:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# nanovllm/kvcache/sparse/kernels.py
|
|
||||||
# 保留完整的 Triton 实现,供未来 GPU-only 模式使用
|
|
||||||
|
|
||||||
def softmax_fuse_block_sum(attn_weights_slice, ...):
|
|
||||||
"""Triton softmax + block sum wrapper"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def flat_group_gemm_fuse_reshape(query_states, key_states, ...):
|
|
||||||
"""Triton GEMM + reshape wrapper"""
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
**原因**:
|
|
||||||
- 未来可以支持 GPU-only 模式的完整 XAttention
|
|
||||||
- Triton kernels 已实现,无需删除
|
|
||||||
- 保持代码完整性
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 5. 实现细节
|
|
||||||
|
|
||||||
### 5.1 文件结构
|
|
||||||
|
|
||||||
```
|
|
||||||
nanovllm/kvcache/sparse/
|
|
||||||
├── __init__.py # 策略注册
|
|
||||||
├── policy.py # 基类定义
|
|
||||||
├── full_policy.py # Full attention 策略
|
|
||||||
├── quest.py # Quest 策略
|
|
||||||
├── minference.py # MInference 策略
|
|
||||||
├── xattn.py # XAttention 策略(新增)
|
|
||||||
├── utils.py # 工具函数(新增)
|
|
||||||
└── kernels.py # Triton kernels(新增)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5.2 utils.py 实现
|
|
||||||
|
|
||||||
```python
|
|
||||||
"""
|
|
||||||
Sparse attention utility functions.
|
|
||||||
Copied and adapted from COMPASS/compass/src/utils.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def find_blocks_chunked(
|
|
||||||
input_tensor,
|
|
||||||
current_index,
|
|
||||||
threshold,
|
|
||||||
num_to_choose,
|
|
||||||
decoding: bool,
|
|
||||||
mode: str = "both",
|
|
||||||
causal=True,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Select blocks based on threshold.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_tensor: [batch, heads, q_blocks, k_blocks] importance scores
|
|
||||||
current_index: Current chunk index
|
|
||||||
threshold: Block selection threshold (0-1)
|
|
||||||
num_to_choose: Number of blocks to choose (if None, use threshold)
|
|
||||||
decoding: Whether in decode mode
|
|
||||||
mode: Selection mode ("prefill", "decoding", "both")
|
|
||||||
causal: Apply causal mask
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
boolean mask: [batch, heads, q_blocks, k_blocks]
|
|
||||||
"""
|
|
||||||
batch_size, head_num, chunk_q, block_k = input_tensor.shape
|
|
||||||
|
|
||||||
if num_to_choose is None:
|
|
||||||
# Threshold-based selection
|
|
||||||
score_threshold = input_tensor.max() * threshold
|
|
||||||
masks = (input_tensor >= score_threshold)
|
|
||||||
else:
|
|
||||||
# Top-k selection
|
|
||||||
topk_values, _ = torch.topk(
|
|
||||||
input_tensor.flatten(start_dim=2),
|
|
||||||
k=num_to_choose,
|
|
||||||
dim=-1
|
|
||||||
)
|
|
||||||
score_threshold = topk_values[..., -1:].unsqueeze(-1)
|
|
||||||
masks = (input_tensor >= score_threshold)
|
|
||||||
|
|
||||||
# Causal mask
|
|
||||||
if causal and chunk_q > 1:
|
|
||||||
for q_idx in range(chunk_q):
|
|
||||||
k_start = current_index + q_idx
|
|
||||||
masks[:, :, q_idx, :k_start] = False
|
|
||||||
|
|
||||||
return masks
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5.3 kernels.py 实现
|
|
||||||
|
|
||||||
```python
|
|
||||||
"""
|
|
||||||
Triton kernels for XAttention sparse attention.
|
|
||||||
|
|
||||||
Copied and adapted from COMPASS/compass/src/kernels.py
|
|
||||||
|
|
||||||
Requirements:
|
|
||||||
- Triton >= 2.1.0
|
|
||||||
- CUDA compute capability SM 80+ (RTX 3090, A100, H100, etc.)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import math
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def softmax_fuse_block_sum_kernel_causal(
|
|
||||||
In, Out, scale,
|
|
||||||
input_stride_0, input_stride_1, input_stride_2,
|
|
||||||
output_stride_0, output_stride_1, output_stride_2,
|
|
||||||
real_q_len, k_len, chunk_start, chunk_end,
|
|
||||||
segment_size: tl.constexpr,
|
|
||||||
block_size: tl.constexpr,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Causal softmax with block sum aggregation.
|
|
||||||
|
|
||||||
Online softmax algorithm:
|
|
||||||
m_i = max(m_i, m_new)
|
|
||||||
l_i = l_i * exp(m_i - m_new) + l_new
|
|
||||||
"""
|
|
||||||
block_id = tl.program_id(0)
|
|
||||||
head_id = tl.program_id(1)
|
|
||||||
batch_id = tl.program_id(2)
|
|
||||||
|
|
||||||
# ... (完整实现见源码)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def flat_group_gemm_fuse_reshape_kernel(
|
|
||||||
Q, K, Out,
|
|
||||||
stride_qz, stride_qh, stride_qn,
|
|
||||||
stride_kz, stride_kh, stride_kn,
|
|
||||||
stride_oz, stride_oh, stride_on,
|
|
||||||
chunk_start, chunk_end,
|
|
||||||
H: tl.constexpr,
|
|
||||||
STRIDE: tl.constexpr,
|
|
||||||
HEAD_DIM: tl.constexpr,
|
|
||||||
BLOCK_M: tl.constexpr,
|
|
||||||
BLOCK_N: tl.constexpr,
|
|
||||||
is_causal: tl.constexpr,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Stride-based GEMM with reshape fusion.
|
|
||||||
"""
|
|
||||||
# ... (完整实现见源码)
|
|
||||||
|
|
||||||
|
|
||||||
def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size,
|
|
||||||
segment_size, chunk_start, chunk_end,
|
|
||||||
real_q_len, scale, is_causal=True):
|
|
||||||
"""Wrapper for Triton softmax-fuse-block-sum kernel."""
|
|
||||||
# ... (完整实现见源码)
|
|
||||||
|
|
||||||
|
|
||||||
def flat_group_gemm_fuse_reshape(query_states, key_states, stride,
|
|
||||||
chunk_start, chunk_end, is_causal=True):
|
|
||||||
"""Wrapper for Triton flat-group-gemm-fuse-reshape kernel."""
|
|
||||||
# ... (完整实现见源码)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5.4 xattn.py 实现
|
|
||||||
|
|
||||||
```python
|
|
||||||
"""
|
|
||||||
XAttention sparse attention policy for nano-vllm.
|
|
||||||
|
|
||||||
Implements the XAttention algorithm from COMPASS, using chunked estimation
|
|
||||||
and block sparse attention for efficient long-context inference.
|
|
||||||
|
|
||||||
Reference: COMPASS/compass/src/Xattention.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import List, Optional
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
|
||||||
from nanovllm.kvcache.sparse.kernels import (
|
|
||||||
flat_group_gemm_fuse_reshape,
|
|
||||||
softmax_fuse_block_sum,
|
|
||||||
)
|
|
||||||
from nanovllm.kvcache.sparse.utils import find_blocks_chunked
|
|
||||||
|
|
||||||
|
|
||||||
class XAttentionPolicy(SparsePolicy):
|
|
||||||
"""
|
|
||||||
XAttention sparse prefill policy using chunked estimation + block sparse attention.
|
|
||||||
|
|
||||||
Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
|
|
||||||
"""
|
|
||||||
|
|
||||||
supports_prefill = True
|
|
||||||
supports_decode = False # XAttention is prefill-only
|
|
||||||
requires_block_selection = False # Only affects attention computation
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
stride: int = 8,
|
|
||||||
threshold: float = 0.9,
|
|
||||||
chunk_size: Optional[int] = None,
|
|
||||||
use_triton: bool = True,
|
|
||||||
keep_sink: bool = False,
|
|
||||||
keep_recent: bool = False,
|
|
||||||
norm: float = 1.0,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize XAttention policy.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stride: Stride for reorganizing Q/K (default: 8)
|
|
||||||
threshold: Block selection threshold, 0-1 (default: 0.9)
|
|
||||||
chunk_size: Chunk size for estimation (auto if None)
|
|
||||||
use_triton: Use Triton kernels (requires SM 80+)
|
|
||||||
keep_sink: Always keep first block (sink tokens)
|
|
||||||
keep_recent: Always keep recent diagonal blocks
|
|
||||||
norm: Normalization factor for attention scores
|
|
||||||
"""
|
|
||||||
self.stride = stride
|
|
||||||
self.threshold = threshold
|
|
||||||
self.chunk_size = chunk_size
|
|
||||||
self.use_triton = use_triton
|
|
||||||
self.keep_sink = keep_sink
|
|
||||||
self.keep_recent = keep_recent
|
|
||||||
self.norm = norm
|
|
||||||
|
|
||||||
# Check Triton availability
|
|
||||||
if self.use_triton:
|
|
||||||
try:
|
|
||||||
import triton
|
|
||||||
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
|
||||||
if props.major < 8:
|
|
||||||
self.use_triton = False
|
|
||||||
print(f"XAttention: Triton requires SM 80+, got SM {props.major}{props.minor}. Falling back to PyTorch.")
|
|
||||||
except ImportError:
|
|
||||||
self.use_triton = False
|
|
||||||
print("XAttention: Triton not available. Falling back to PyTorch.")
|
|
||||||
|
|
||||||
def select_blocks(
|
|
||||||
self,
|
|
||||||
available_blocks: List[int],
|
|
||||||
ctx: PolicyContext,
|
|
||||||
) -> List[int]:
|
|
||||||
"""
|
|
||||||
Select blocks for decode phase.
|
|
||||||
|
|
||||||
XAttention is prefill-only, so this method is only used as a fallback.
|
|
||||||
Returns all available blocks by default.
|
|
||||||
"""
|
|
||||||
# XAttention is prefill-only, but we need to implement this abstract method
|
|
||||||
# Since requires_block_selection=False, this won't be called for loading
|
|
||||||
return available_blocks
|
|
||||||
|
|
||||||
def sparse_prefill_attention(
|
|
||||||
self,
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
layer_id: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute XAttention sparse attention for prefill.
|
|
||||||
|
|
||||||
For CPU offload mode, uses FlashAttention directly with native GQA support.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
q: Query tensor [seq_len, num_heads, head_dim]
|
|
||||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
layer_id: Current transformer layer index
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Attention output [seq_len, num_heads, head_dim]
|
|
||||||
"""
|
|
||||||
seq_len = q.shape[0]
|
|
||||||
num_heads = q.shape[1]
|
|
||||||
head_dim = q.shape[2]
|
|
||||||
num_kv_heads = k.shape[1]
|
|
||||||
|
|
||||||
# Use FlashAttention directly for CPU offload mode
|
|
||||||
# FlashAttention supports GQA natively
|
|
||||||
try:
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
|
||||||
|
|
||||||
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
|
||||||
|
|
||||||
attn_output = flash_attn_varlen_func(
|
|
||||||
q, k, v,
|
|
||||||
cu_seqlens_q=cu_seqlens,
|
|
||||||
cu_seqlens_k=cu_seqlens,
|
|
||||||
max_seqlen_q=seq_len,
|
|
||||||
max_seqlen_k=seq_len,
|
|
||||||
softmax_scale=1.0 / math.sqrt(head_dim),
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Fallback: PyTorch SDPA (supports GQA natively)
|
|
||||||
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
|
|
||||||
attn_output = F.scaled_dot_product_attention(
|
|
||||||
q, k, v,
|
|
||||||
attn_mask=None,
|
|
||||||
is_causal=True,
|
|
||||||
scale=1.0 / math.sqrt(head_dim)
|
|
||||||
)
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Reset policy state (no state to reset for XAttention)."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return (f"XAttentionPolicy("
|
|
||||||
f"stride={self.stride}, "
|
|
||||||
f"threshold={self.threshold}, "
|
|
||||||
f"use_triton={self.use_triton})")
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5.5 框架集成
|
|
||||||
|
|
||||||
**config.py - 添加配置参数**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class SparsePolicyType(Enum):
|
|
||||||
"""Sparse attention policy types."""
|
|
||||||
FULL = auto()
|
|
||||||
QUEST = auto()
|
|
||||||
MINFERENCE = auto()
|
|
||||||
XATTN = auto() # 新增
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Config:
|
|
||||||
# ... 其他配置
|
|
||||||
|
|
||||||
# XAttention configuration
|
|
||||||
xattn_stride: int = 8
|
|
||||||
xattn_threshold: float = 0.9
|
|
||||||
xattn_chunk_size: int = 16384
|
|
||||||
xattn_use_triton: bool = True
|
|
||||||
xattn_keep_sink: bool = False
|
|
||||||
xattn_keep_recent: bool = False
|
|
||||||
xattn_norm: float = 1.0
|
|
||||||
```
|
|
||||||
|
|
||||||
**__init__.py - 注册策略**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
|
||||||
if policy_type == SparsePolicyType.XATTN:
|
|
||||||
return XAttentionPolicy(
|
|
||||||
stride=kwargs.get("stride", 8),
|
|
||||||
threshold=kwargs.get("threshold", 0.9),
|
|
||||||
chunk_size=kwargs.get("chunk_size", 16384),
|
|
||||||
use_triton=kwargs.get("use_triton", True),
|
|
||||||
keep_sink=kwargs.get("keep_sink", False),
|
|
||||||
keep_recent=kwargs.get("keep_recent", False),
|
|
||||||
norm=kwargs.get("norm", 1.0),
|
|
||||||
)
|
|
||||||
# ... 其他策略
|
|
||||||
```
|
|
||||||
|
|
||||||
**model_runner.py - 使用策略**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# 在 SparsePolicy 初始化时自动选择
|
|
||||||
if self.config.sparse_policy == SparsePolicyType.XATTN:
|
|
||||||
self.sparse_prefill_policy = XAttentionPolicy(...)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 6. 问题与解决方案
|
|
||||||
|
|
||||||
### 6.1 问题 1: Abstract Method Not Implemented
|
|
||||||
|
|
||||||
**错误**:
|
|
||||||
```python
|
|
||||||
TypeError: Can't instantiate abstract class XAttentionPolicy
|
|
||||||
with abstract method select_blocks
|
|
||||||
```
|
|
||||||
|
|
||||||
**原因**:
|
|
||||||
- `SparsePolicy` 是抽象基类,要求子类实现 `select_blocks()`
|
|
||||||
- XAttention 是 prefill-only 策略,不需要 block selection
|
|
||||||
|
|
||||||
**解决**:
|
|
||||||
```python
|
|
||||||
def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
|
|
||||||
"""
|
|
||||||
Select blocks for decode phase.
|
|
||||||
|
|
||||||
XAttention is prefill-only, so this method is only used as a fallback.
|
|
||||||
Returns all available blocks by default.
|
|
||||||
"""
|
|
||||||
# Since requires_block_selection=False, this won't be called for loading
|
|
||||||
return available_blocks
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6.2 问题 2: CUDA OOM During Estimation
|
|
||||||
|
|
||||||
**错误**:
|
|
||||||
```
|
|
||||||
CUDA out of memory. Tried to allocate 1013.92 GiB
|
|
||||||
```
|
|
||||||
|
|
||||||
**原因**:
|
|
||||||
- `_xattn_estimate()` 使用 `q_len` 计算 `k_block_num`
|
|
||||||
- 但在 chunked prefill 中,`q_len` 是当前 chunk 大小(2048)
|
|
||||||
- 而不是完整上下文长度(32768)
|
|
||||||
- 导致 padding 计算错误
|
|
||||||
|
|
||||||
**原始代码问题**:
|
|
||||||
```python
|
|
||||||
batch_size, num_heads, k_len, head_dim = key_states.shape
|
|
||||||
batch_size, num_heads, q_len, head_dim = query_states.shape
|
|
||||||
|
|
||||||
# 错误:使用 q_len 计算 k_block_num
|
|
||||||
k_block_num = (k_len + k_num_to_pad) // block_size # 应该用完整 k_len
|
|
||||||
```
|
|
||||||
|
|
||||||
**解决**:
|
|
||||||
简化实现,直接使用 FlashAttention:
|
|
||||||
```python
|
|
||||||
def sparse_prefill_attention(self, q, k, v, layer_id):
|
|
||||||
# 使用 FlashAttention 直接计算
|
|
||||||
# 不进行 chunked estimation(与 offload 架构不兼容)
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
|
||||||
...
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6.3 问题 3: GQA Head Count Mismatch
|
|
||||||
|
|
||||||
**错误**:
|
|
||||||
```
|
|
||||||
ValueError: Number of heads in key/value must divide number of heads in query
|
|
||||||
```
|
|
||||||
|
|
||||||
**原因**:
|
|
||||||
- Llama-3.1-8B 使用 GQA:num_heads=32, num_kv_heads=8
|
|
||||||
- 原始 XAttention 代码手动展开 KV heads:
|
|
||||||
```python
|
|
||||||
# 错误方式
|
|
||||||
if num_kv_heads != num_heads:
|
|
||||||
key_states = key_states.repeat_interleave(num_heads // num_kv_heads, dim=1)
|
|
||||||
```
|
|
||||||
|
|
||||||
**解决**:
|
|
||||||
依赖 FlashAttention 的原生 GQA 支持:
|
|
||||||
```python
|
|
||||||
# FlashAttention 自动处理 GQA,无需手动展开
|
|
||||||
attn_output = flash_attn_varlen_func(
|
|
||||||
q, k, v, # k, v 可以有更少的 heads
|
|
||||||
...
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6.4 Bug Fix: kernels.py Line 106
|
|
||||||
|
|
||||||
**原始代码**:
|
|
||||||
```python
|
|
||||||
for iter in range(num_iters_before_causal + 1, num_iters):
|
|
||||||
X = torch.zeros([segment_size // block_size], dtype=torch.float32) # 错误
|
|
||||||
```
|
|
||||||
|
|
||||||
**修复**:
|
|
||||||
```python
|
|
||||||
for iter in range(num_iters_before_causal + 1, num_iters):
|
|
||||||
X = tl.zeros([segment_size // block_size], dtype=torch.float32) # 正确
|
|
||||||
```
|
|
||||||
|
|
||||||
**原因**:
|
|
||||||
- Triton JIT kernel 中必须使用 `tl.zeros` 而不是 `torch.zeros`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 7. 测试验证
|
|
||||||
|
|
||||||
### 7.1 测试环境
|
|
||||||
|
|
||||||
- **模型**: Llama-3.1-8B-Instruct
|
|
||||||
- **GPU**: RTX 3090 (24GB)
|
|
||||||
- **数据集**: RULER 32k benchmark
|
|
||||||
- **模式**: CPU offload enabled
|
|
||||||
|
|
||||||
### 7.2 测试命令
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# NIAH 任务测试
|
|
||||||
CUDA_VISIBLE_DEVICES=4 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
|
||||||
python tests/test_ruler.py \
|
|
||||||
--data-dir tests/data/ruler_32k \
|
|
||||||
--enable-offload \
|
|
||||||
--sparse-policy XATTN \
|
|
||||||
--num-samples 3 \
|
|
||||||
--datasets niah_single_1,niah_multikey_1,niah_multiquery,niah_multivalue \
|
|
||||||
--max-model-len 32896
|
|
||||||
|
|
||||||
# QA/Recall 任务测试(并行运行)
|
|
||||||
CUDA_VISIBLE_DEVICES=5 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
|
|
||||||
python tests/test_ruler.py \
|
|
||||||
--data-dir tests/data/ruler_32k \
|
|
||||||
--enable-offload \
|
|
||||||
--sparse-policy XATTN \
|
|
||||||
--num-samples 3 \
|
|
||||||
--datasets qa_1,qa_2,vt,cwe,fwe \
|
|
||||||
--max-model-len 32896
|
|
||||||
```
|
|
||||||
|
|
||||||
### 7.3 测试结果
|
|
||||||
|
|
||||||
#### GPU 4 - NIAH 任务
|
|
||||||
|
|
||||||
| 任务 | 通过/总数 | 准确率 | 平均分 |
|
|
||||||
|------|----------|--------|--------|
|
|
||||||
| niah_single_1 | 3/3 | 100.0% | 1.000 |
|
|
||||||
| niah_multikey_1 | 3/3 | 100.0% | 1.000 |
|
|
||||||
| niah_multiquery | 3/3 | 100.0% | 1.000 |
|
|
||||||
| niah_multivalue | 3/3 | 100.0% | 1.000 |
|
|
||||||
| **NIAH 总计** | **12/12** | **100.0%** | **1.000** |
|
|
||||||
|
|
||||||
#### GPU 5 - QA/Recall 任务
|
|
||||||
|
|
||||||
| 任务 | 通过/总数 | 准确率 | 平均分 |
|
|
||||||
|------|----------|--------|--------|
|
|
||||||
| qa_1 | 2/3 | 66.7% | 0.667 |
|
|
||||||
| qa_2 | 1/3 | 33.3% | 0.333 |
|
|
||||||
| vt | 3/3 | 100.0% | 0.867 |
|
|
||||||
| cwe | 2/3 | 66.7% | 0.467 |
|
|
||||||
| fwe | 3/3 | 100.0% | 0.889 |
|
|
||||||
| **QA/Recall 总计** | **11/15** | **73.3%** | **0.644** |
|
|
||||||
|
|
||||||
#### 总体结果
|
|
||||||
|
|
||||||
- **总计**: 23/27 样本通过 (85.2% 准确率)
|
|
||||||
- **耗时**: GPU 4 (74.9s), GPU 5 (425.1s)
|
|
||||||
- **结论**: XAttention 集成成功,test_ruler.py 全部通过 ✅
|
|
||||||
|
|
||||||
### 7.4 内存使用
|
|
||||||
|
|
||||||
```
|
|
||||||
OffloadEngine initialized: GPU=650.0MB, CPU=4224.0MB
|
|
||||||
Ring buffer GPU cache: 522.0 MB (4 buffers × 33408 tokens)
|
|
||||||
CPU cache: 4224.0 MB (32 layers × 33 blocks)
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 8. 使用指南
|
|
||||||
|
|
||||||
### 8.1 基本用法
|
|
||||||
|
|
||||||
```python
|
|
||||||
from nanovllm import LLM, SamplingParams
|
|
||||||
from nanovllm.config import SparsePolicyType
|
|
||||||
|
|
||||||
llm = LLM(
|
|
||||||
model_path="/path/to/model",
|
|
||||||
enable_cpu_offload=True,
|
|
||||||
sparse_policy=SparsePolicyType.XATTN,
|
|
||||||
xattn_threshold=0.9,
|
|
||||||
xattn_stride=8,
|
|
||||||
)
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(temperature=0.1, max_tokens=128)
|
|
||||||
outputs = llm.generate(["Your prompt here"], sampling_params)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 8.2 命令行测试
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# RULER benchmark
|
|
||||||
python tests/test_ruler.py \
|
|
||||||
--model ~/models/Llama-3.1-8B-Instruct \
|
|
||||||
--data-dir tests/data/ruler_32k \
|
|
||||||
--enable-offload \
|
|
||||||
--sparse-policy XATTN \
|
|
||||||
--max-model-len 32896
|
|
||||||
|
|
||||||
# 单个样本测试
|
|
||||||
python tests/test_needle.py \
|
|
||||||
--model ~/models/Llama-3.1-8B-Instruct \
|
|
||||||
--enable-offload \
|
|
||||||
--sparse-policy XATTN
|
|
||||||
```
|
|
||||||
|
|
||||||
### 8.3 配置参数
|
|
||||||
|
|
||||||
| 参数 | 默认值 | 说明 |
|
|
||||||
|------|--------|------|
|
|
||||||
| `sparse_policy` | `FULL` | 稀疏策略类型 (FULL, QUEST, MINFERENCE, XATTN) |
|
|
||||||
| `xattn_threshold` | 0.9 | Block 选择阈值 (0-1) |
|
|
||||||
| `xattn_stride` | 8 | Q/K 重组步长 |
|
|
||||||
| `xattn_chunk_size` | 16384 | Estimation chunk 大小 |
|
|
||||||
| `xattn_use_triton` | True | 是否使用 Triton kernels |
|
|
||||||
|
|
||||||
### 8.4 与其他策略对比
|
|
||||||
|
|
||||||
| 策略 | 阶段 | 用途 | 优势 |
|
|
||||||
|------|------|------|------|
|
|
||||||
| FULL | prefill + decode | 基线 | 准确率最高 |
|
|
||||||
| QUEST | decode only | Top-K block selection | 适合 decode 优化 |
|
|
||||||
| MINFERENCE | prefill | Vertical + Slash pattern | GPU-only 高效 |
|
|
||||||
| XATTN | prefill only | Chunked estimation + block sparse | 长上下文 prefill |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 附录
|
|
||||||
|
|
||||||
### A. 相关文档
|
|
||||||
|
|
||||||
- [`sparse_attention_guide.md`](sparse_attention_guide.md) - 稀疏注意力方法概述
|
|
||||||
- [`sparse_offload_integration.md`](sparse_offload_integration.md) - 稀疏策略与 offload 集成
|
|
||||||
- [`block_sparse_attention_lib.md`](block_sparse_attention_lib.md) - Block-Sparse-Attention 库参考
|
|
||||||
|
|
||||||
### B. Git 历史
|
|
||||||
|
|
||||||
- `ac1ccbc` - feat: add XAttention sparse policy integration
|
|
||||||
- `57f4e9c` - docs: reorganize documentation files
|
|
||||||
|
|
||||||
### C. 待办事项
|
|
||||||
|
|
||||||
- [ ] GPU-only 模式下的完整 XAttention 实现(使用 Triton kernels)
|
|
||||||
- [ ] 性能基准测试(与 FULL、MINFERENCE 对比)
|
|
||||||
- [ ] 自适应 threshold 调整
|
|
||||||
- [ ] 更多上下文长度测试(64k, 128k)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
**作者**: Zijie Tian
|
|
||||||
**日期**: 2026-01-14
|
|
||||||
**版本**: 1.0
|
|
||||||
160
findings.md
Normal file
160
findings.md
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
# Findings: Multi-Model Support Analysis
|
||||||
|
|
||||||
|
## Current Architecture Analysis
|
||||||
|
|
||||||
|
### Model Loading Flow
|
||||||
|
```
|
||||||
|
LLM(model_path)
|
||||||
|
→ LLMEngine.__init__()
|
||||||
|
→ Config.__post_init__()
|
||||||
|
→ hf_config = AutoConfig.from_pretrained(model)
|
||||||
|
→ ModelRunner.__init__()
|
||||||
|
→ model = Qwen3ForCausalLM(hf_config) ← HARDCODED
|
||||||
|
→ load_model(model, config.model)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Key Files
|
||||||
|
| File | Purpose |
|
||||||
|
|------|---------|
|
||||||
|
| `nanovllm/engine/model_runner.py` | 模型加载和运行 |
|
||||||
|
| `nanovllm/models/qwen3.py` | Qwen3 模型定义 |
|
||||||
|
| `nanovllm/utils/loader.py` | safetensors 权重加载 |
|
||||||
|
| `nanovllm/layers/rotary_embedding.py` | RoPE 实现 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Llama 3.1 Config Analysis
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"architectures": ["LlamaForCausalLM"],
|
||||||
|
"model_type": "llama",
|
||||||
|
"attention_bias": false,
|
||||||
|
"mlp_bias": false,
|
||||||
|
"head_dim": 128,
|
||||||
|
"hidden_size": 4096,
|
||||||
|
"intermediate_size": 14336,
|
||||||
|
"num_attention_heads": 32,
|
||||||
|
"num_hidden_layers": 32,
|
||||||
|
"num_key_value_heads": 8,
|
||||||
|
"hidden_act": "silu",
|
||||||
|
"rms_norm_eps": 1e-05,
|
||||||
|
"rope_theta": 500000.0,
|
||||||
|
"rope_scaling": {
|
||||||
|
"factor": 8.0,
|
||||||
|
"high_freq_factor": 4.0,
|
||||||
|
"low_freq_factor": 1.0,
|
||||||
|
"original_max_position_embeddings": 8192,
|
||||||
|
"rope_type": "llama3"
|
||||||
|
},
|
||||||
|
"max_position_embeddings": 131072,
|
||||||
|
"tie_word_embeddings": false,
|
||||||
|
"vocab_size": 128256
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Llama 3 RoPE Scaling
|
||||||
|
Llama 3 使用特殊的 RoPE scaling 策略 (`rope_type: "llama3"`):
|
||||||
|
- 低频分量保持不变(对应短距离依赖)
|
||||||
|
- 高频分量线性插值(对应长距离依赖)
|
||||||
|
- 参数: `factor`, `low_freq_factor`, `high_freq_factor`, `original_max_position_embeddings`
|
||||||
|
|
||||||
|
参考实现 (transformers):
|
||||||
|
```python
|
||||||
|
def _compute_llama3_parameters(config, device, inv_freq):
|
||||||
|
factor = config.factor
|
||||||
|
low_freq_factor = config.low_freq_factor
|
||||||
|
high_freq_factor = config.high_freq_factor
|
||||||
|
old_context_len = config.original_max_position_embeddings
|
||||||
|
|
||||||
|
low_freq_wavelen = old_context_len / low_freq_factor
|
||||||
|
high_freq_wavelen = old_context_len / high_freq_factor
|
||||||
|
|
||||||
|
wavelen = 2 * math.pi / inv_freq
|
||||||
|
inv_freq_llama = torch.where(
|
||||||
|
wavelen > low_freq_wavelen,
|
||||||
|
inv_freq / factor,
|
||||||
|
inv_freq
|
||||||
|
)
|
||||||
|
smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
|
||||||
|
smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama + smooth_factor * inv_freq
|
||||||
|
is_medium_freq = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
|
||||||
|
inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
|
||||||
|
return inv_freq_llama
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Weight Mapping Analysis
|
||||||
|
|
||||||
|
### Qwen3 packed_modules_mapping
|
||||||
|
```python
|
||||||
|
packed_modules_mapping = {
|
||||||
|
"q_proj": ("qkv_proj", "q"),
|
||||||
|
"k_proj": ("qkv_proj", "k"),
|
||||||
|
"v_proj": ("qkv_proj", "v"),
|
||||||
|
"gate_proj": ("gate_up_proj", 0),
|
||||||
|
"up_proj": ("gate_up_proj", 1),
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Llama Weight Names (from safetensors)
|
||||||
|
预期 Llama 权重命名与 Qwen3 类似:
|
||||||
|
- `model.layers.{i}.self_attn.q_proj.weight`
|
||||||
|
- `model.layers.{i}.self_attn.k_proj.weight`
|
||||||
|
- `model.layers.{i}.self_attn.v_proj.weight`
|
||||||
|
- `model.layers.{i}.self_attn.o_proj.weight`
|
||||||
|
- `model.layers.{i}.mlp.gate_proj.weight`
|
||||||
|
- `model.layers.{i}.mlp.up_proj.weight`
|
||||||
|
- `model.layers.{i}.mlp.down_proj.weight`
|
||||||
|
- `model.layers.{i}.input_layernorm.weight`
|
||||||
|
- `model.layers.{i}.post_attention_layernorm.weight`
|
||||||
|
|
||||||
|
**结论**: Llama 的 `packed_modules_mapping` 与 Qwen3 相同,可以复用。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Shared Components (Can Reuse)
|
||||||
|
|
||||||
|
| Component | File | Notes |
|
||||||
|
|-----------|------|-------|
|
||||||
|
| `RMSNorm` | `layers/layernorm.py` | 通用 |
|
||||||
|
| `SiluAndMul` | `layers/activation.py` | 通用 |
|
||||||
|
| `Attention` | `layers/attention.py` | FlashAttention wrapper |
|
||||||
|
| `QKVParallelLinear` | `layers/linear.py` | 支持 bias=False |
|
||||||
|
| `RowParallelLinear` | `layers/linear.py` | 通用 |
|
||||||
|
| `MergedColumnParallelLinear` | `layers/linear.py` | 通用 |
|
||||||
|
| `VocabParallelEmbedding` | `layers/embed_head.py` | 通用 |
|
||||||
|
| `ParallelLMHead` | `layers/embed_head.py` | 通用 |
|
||||||
|
| `load_model` | `utils/loader.py` | 通用 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Llama vs Qwen3 Implementation Diff
|
||||||
|
|
||||||
|
### Attention
|
||||||
|
| Feature | Qwen3Attention | LlamaAttention |
|
||||||
|
|---------|----------------|----------------|
|
||||||
|
| QKV bias | 可配置 (attention_bias) | 始终 False |
|
||||||
|
| q_norm | 有 (when bias=False) | 无 |
|
||||||
|
| k_norm | 有 (when bias=False) | 无 |
|
||||||
|
| RoPE | Standard | Llama3 scaled |
|
||||||
|
|
||||||
|
### MLP
|
||||||
|
| Feature | Qwen3MLP | LlamaMLP |
|
||||||
|
|---------|----------|----------|
|
||||||
|
| gate/up bias | False | False |
|
||||||
|
| down bias | False | False |
|
||||||
|
| hidden_act | silu | silu |
|
||||||
|
|
||||||
|
**结论**: Llama MLP 与 Qwen3 MLP 几乎相同,可以直接复用或简化。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Risk Assessment
|
||||||
|
|
||||||
|
| Risk | Impact | Mitigation |
|
||||||
|
|------|--------|------------|
|
||||||
|
| RoPE 实现错误 | 高 - 导致错误输出 | 参考 transformers 实现,单元测试 |
|
||||||
|
| 权重映射错误 | 高 - 模型无法加载 | 检查 safetensors 键名 |
|
||||||
|
| 注册表循环导入 | 中 - 启动失败 | 延迟导入 |
|
||||||
@@ -9,8 +9,6 @@ class SparsePolicyType(Enum):
|
|||||||
"""Sparse attention policy types."""
|
"""Sparse attention policy types."""
|
||||||
FULL = auto() # No sparse attention (load all blocks)
|
FULL = auto() # No sparse attention (load all blocks)
|
||||||
QUEST = auto() # Query-aware Top-K block selection (decode only)
|
QUEST = auto() # Query-aware Top-K block selection (decode only)
|
||||||
MINFERENCE = auto() # MInference vertical + slash sparse prefill (GPU-only)
|
|
||||||
XATTN = auto() # XAttention chunked estimation + block-sparse attention
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -33,7 +31,6 @@ class Config:
|
|||||||
offload_policy: str = "lru" # "lru", "fifo", or full class path
|
offload_policy: str = "lru" # "lru", "fifo", or full class path
|
||||||
num_transfer_streams: int = 4 # Number of CUDA streams for async transfers
|
num_transfer_streams: int = 4 # Number of CUDA streams for async transfers
|
||||||
num_gpu_blocks: int = -1 # User-specified GPU blocks count, -1 = auto (use max available)
|
num_gpu_blocks: int = -1 # User-specified GPU blocks count, -1 = auto (use max available)
|
||||||
num_kv_buffers: int = 4 # Ring buffer size for layer-wise offload (decode H2D pipeline)
|
|
||||||
|
|
||||||
# Computed fields for offload (set in __post_init__ or by ModelRunner)
|
# Computed fields for offload (set in __post_init__ or by ModelRunner)
|
||||||
num_gpu_kvcache_blocks: int = -1
|
num_gpu_kvcache_blocks: int = -1
|
||||||
@@ -42,27 +39,10 @@ class Config:
|
|||||||
# Sparse attention configuration
|
# Sparse attention configuration
|
||||||
# Quest: decode-only sparse attention with Top-K block selection
|
# Quest: decode-only sparse attention with Top-K block selection
|
||||||
# FULL: no sparse attention (load all blocks)
|
# FULL: no sparse attention (load all blocks)
|
||||||
# MINFERENCE: MInference vertical + slash sparse prefill (GPU-only)
|
|
||||||
sparse_policy: SparsePolicyType = SparsePolicyType.FULL
|
sparse_policy: SparsePolicyType = SparsePolicyType.FULL
|
||||||
sparse_topk_blocks: int = 8 # Top-K blocks for Quest
|
sparse_topk_blocks: int = 8 # Top-K blocks for Quest
|
||||||
sparse_threshold_blocks: int = 4 # Apply sparse only when blocks > threshold
|
sparse_threshold_blocks: int = 4 # Apply sparse only when blocks > threshold
|
||||||
|
|
||||||
# MInference configuration (used when sparse_policy == MINFERENCE)
|
|
||||||
minference_adaptive_budget: float = 0.3 # Budget as fraction of seq_len (None to use fixed sizes)
|
|
||||||
minference_vertical_size: int = 1000 # Fixed vertical size (if adaptive_budget is None)
|
|
||||||
minference_slash_size: int = 6096 # Fixed slash size (if adaptive_budget is None)
|
|
||||||
minference_num_sink_tokens: int = 30 # Sink tokens to always keep
|
|
||||||
minference_num_recent_diags: int = 100 # Recent diagonals to always keep
|
|
||||||
|
|
||||||
# XAttention configuration (used when sparse_policy == XATTN)
|
|
||||||
xattn_stride: int = 8 # Stride for reorganizing Q/K
|
|
||||||
xattn_threshold: float = 0.9 # Block selection threshold (0-1)
|
|
||||||
xattn_chunk_size: int = 16384 # Chunk size for estimation (auto if None)
|
|
||||||
xattn_use_triton: bool = True # Use Triton kernels (requires SM 80+)
|
|
||||||
xattn_keep_sink: bool = False # Always keep first block (sink tokens)
|
|
||||||
xattn_keep_recent: bool = False # Always keep recent diagonal blocks
|
|
||||||
xattn_norm: float = 1.0 # Normalization factor for attention scores
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
assert os.path.isdir(self.model)
|
assert os.path.isdir(self.model)
|
||||||
assert self.kvcache_block_size % 256 == 0
|
assert self.kvcache_block_size % 256 == 0
|
||||||
@@ -71,15 +51,6 @@ class Config:
|
|||||||
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
|
self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings)
|
||||||
assert self.max_num_batched_tokens >= self.max_model_len
|
assert self.max_num_batched_tokens >= self.max_model_len
|
||||||
|
|
||||||
# CPU offload mode only supports single sequence (layer-wise processing)
|
|
||||||
if self.enable_cpu_offload and self.max_num_seqs != 1:
|
|
||||||
import logging
|
|
||||||
logging.warning(
|
|
||||||
f"CPU offload mode only supports single sequence. "
|
|
||||||
f"Overriding max_num_seqs from {self.max_num_seqs} to 1."
|
|
||||||
)
|
|
||||||
self.max_num_seqs = 1
|
|
||||||
|
|
||||||
# Override torch_dtype if user specified
|
# Override torch_dtype if user specified
|
||||||
if self.dtype is not None:
|
if self.dtype is not None:
|
||||||
dtype_map = {
|
dtype_map = {
|
||||||
|
|||||||
@@ -34,56 +34,14 @@ class LLMEngine:
|
|||||||
# Set Sequence.block_size to match the KV cache block size
|
# Set Sequence.block_size to match the KV cache block size
|
||||||
Sequence.block_size = config.kvcache_block_size
|
Sequence.block_size = config.kvcache_block_size
|
||||||
self.scheduler = Scheduler(config, self.model_runner.kvcache_manager)
|
self.scheduler = Scheduler(config, self.model_runner.kvcache_manager)
|
||||||
self._closed = False
|
atexit.register(self.exit)
|
||||||
atexit.register(self._atexit_handler)
|
|
||||||
|
|
||||||
def _atexit_handler(self):
|
def exit(self):
|
||||||
"""Handler for atexit - only runs if close() wasn't called."""
|
|
||||||
if not self._closed:
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"""Explicitly close the engine and release all resources.
|
|
||||||
|
|
||||||
This method is idempotent - calling it multiple times is safe.
|
|
||||||
Supports: explicit close(), context manager, and __del__ fallback.
|
|
||||||
"""
|
|
||||||
if self._closed:
|
|
||||||
return
|
|
||||||
self._closed = True
|
|
||||||
|
|
||||||
# Unregister atexit to prevent double cleanup
|
|
||||||
try:
|
|
||||||
atexit.unregister(self._atexit_handler)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Cleanup resources
|
|
||||||
self.model_runner.call("exit")
|
self.model_runner.call("exit")
|
||||||
del self.model_runner
|
del self.model_runner
|
||||||
for p in self.ps:
|
for p in self.ps:
|
||||||
p.join()
|
p.join()
|
||||||
|
|
||||||
def exit(self):
|
|
||||||
"""Alias for close() - kept for backward compatibility."""
|
|
||||||
self.close()
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
"""Destructor - attempt cleanup if not already done."""
|
|
||||||
try:
|
|
||||||
self.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
"""Context manager entry."""
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
||||||
"""Context manager exit - ensures cleanup."""
|
|
||||||
self.close()
|
|
||||||
return False
|
|
||||||
|
|
||||||
def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
|
def add_request(self, prompt: str | list[int], sampling_params: SamplingParams):
|
||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
prompt = self.tokenizer.encode(prompt)
|
prompt = self.tokenizer.encode(prompt)
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -36,11 +36,10 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
|||||||
KVCacheManager instance
|
KVCacheManager instance
|
||||||
"""
|
"""
|
||||||
if not getattr(config, 'enable_cpu_offload', False):
|
if not getattr(config, 'enable_cpu_offload', False):
|
||||||
# Default: pure GPU mode with contiguous cache for single-seq optimization
|
# Default: pure GPU mode
|
||||||
return GPUOnlyManager(
|
return GPUOnlyManager(
|
||||||
num_blocks=config.num_kvcache_blocks,
|
num_blocks=config.num_kvcache_blocks,
|
||||||
block_size=config.kvcache_block_size,
|
block_size=config.kvcache_block_size,
|
||||||
max_seq_len=config.max_model_len, # Enable contiguous cache
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# CPU offload is enabled
|
# CPU offload is enabled
|
||||||
@@ -71,20 +70,12 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
|
|||||||
threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
|
threshold_blocks=getattr(config, 'sparse_threshold_blocks', 4),
|
||||||
)
|
)
|
||||||
|
|
||||||
# max_seq_len needs to be larger than max_model_len to accommodate decode tokens
|
|
||||||
# When prefill uses ~max_model_len tokens, decode needs additional slots
|
|
||||||
# Add max_new_tokens (default 512) buffer for decode phase
|
|
||||||
max_new_tokens = getattr(config, 'max_new_tokens', 512)
|
|
||||||
max_seq_len = config.max_model_len + max_new_tokens
|
|
||||||
|
|
||||||
return HybridKVCacheManager(
|
return HybridKVCacheManager(
|
||||||
num_gpu_slots=num_gpu_blocks,
|
num_gpu_slots=num_gpu_blocks,
|
||||||
num_cpu_blocks=num_cpu_blocks,
|
num_cpu_blocks=num_cpu_blocks,
|
||||||
block_size=config.kvcache_block_size,
|
block_size=config.kvcache_block_size,
|
||||||
policy=eviction_policy,
|
policy=eviction_policy,
|
||||||
sparse_policy=sparse_policy,
|
sparse_policy=sparse_policy,
|
||||||
num_kv_buffers=getattr(config, 'num_kv_buffers', 4),
|
|
||||||
max_seq_len=max_seq_len,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -45,24 +45,21 @@ class GPUOnlyManager(KVCacheManager):
|
|||||||
- Paged attention with configurable block size
|
- Paged attention with configurable block size
|
||||||
- Prefix caching via xxhash
|
- Prefix caching via xxhash
|
||||||
- Reference counting for block sharing
|
- Reference counting for block sharing
|
||||||
- Contiguous cache for single-sequence layer-wise prefill (optional)
|
|
||||||
|
|
||||||
This manager is fully compatible with CUDA graphs since
|
This manager is fully compatible with CUDA graphs since
|
||||||
all data stays on GPU at fixed addresses.
|
all data stays on GPU at fixed addresses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_blocks: int, block_size: int, max_seq_len: int = 0):
|
def __init__(self, num_blocks: int, block_size: int):
|
||||||
"""
|
"""
|
||||||
Initialize GPU-only manager.
|
Initialize GPU-only manager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_blocks: Total number of blocks to manage
|
num_blocks: Total number of blocks to manage
|
||||||
block_size: Tokens per block (default 256)
|
block_size: Tokens per block (default 256)
|
||||||
max_seq_len: Max sequence length for contiguous cache (0 to disable)
|
|
||||||
"""
|
"""
|
||||||
self._block_size = block_size
|
self._block_size = block_size
|
||||||
self._num_blocks = num_blocks
|
self._num_blocks = num_blocks
|
||||||
self._max_seq_len = max_seq_len
|
|
||||||
|
|
||||||
# Block metadata
|
# Block metadata
|
||||||
self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]
|
self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]
|
||||||
@@ -80,11 +77,6 @@ class GPUOnlyManager(KVCacheManager):
|
|||||||
self.num_kv_heads: int = 0
|
self.num_kv_heads: int = 0
|
||||||
self.head_dim: int = 0
|
self.head_dim: int = 0
|
||||||
|
|
||||||
# Contiguous cache for single-seq layer-wise prefill (set by allocate_cache)
|
|
||||||
self.contiguous_k_cache: Optional[Tensor] = None
|
|
||||||
self.contiguous_v_cache: Optional[Tensor] = None
|
|
||||||
self.contiguous_seq_len: int = 0 # Current sequence length in contiguous cache
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def block_size(self) -> int:
|
def block_size(self) -> int:
|
||||||
return self._block_size
|
return self._block_size
|
||||||
@@ -113,23 +105,6 @@ class GPUOnlyManager(KVCacheManager):
|
|||||||
dtype=dtype, device="cuda"
|
dtype=dtype, device="cuda"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Allocate contiguous cache for single-seq layer-wise prefill
|
|
||||||
# Only allocate if there's enough free memory (at least 2GB margin)
|
|
||||||
if self._max_seq_len > 0:
|
|
||||||
contiguous_cache_bytes = 2 * num_layers * self._max_seq_len * num_kv_heads * head_dim * dtype.itemsize
|
|
||||||
free_memory = torch.cuda.mem_get_info()[0]
|
|
||||||
|
|
||||||
if free_memory > contiguous_cache_bytes + 2 * 1024**3: # 2GB margin
|
|
||||||
# Shape: [num_layers, max_seq_len, kv_heads, head_dim]
|
|
||||||
self.contiguous_k_cache = torch.empty(
|
|
||||||
num_layers, self._max_seq_len, num_kv_heads, head_dim,
|
|
||||||
dtype=dtype, device="cuda"
|
|
||||||
)
|
|
||||||
self.contiguous_v_cache = torch.empty(
|
|
||||||
num_layers, self._max_seq_len, num_kv_heads, head_dim,
|
|
||||||
dtype=dtype, device="cuda"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||||
"""Get K/V cache for a layer."""
|
"""Get K/V cache for a layer."""
|
||||||
assert self.kv_cache is not None, "Cache not allocated"
|
assert self.kv_cache is not None, "Cache not allocated"
|
||||||
|
|||||||
@@ -65,22 +65,23 @@ class LogicalBlock:
|
|||||||
|
|
||||||
class HybridKVCacheManager(KVCacheManager):
|
class HybridKVCacheManager(KVCacheManager):
|
||||||
"""
|
"""
|
||||||
Hybrid CPU-GPU KV cache manager with layer-wise offload design.
|
Hybrid CPU-GPU KV cache manager with ring buffer design.
|
||||||
|
|
||||||
Architecture (CPU-primary mode):
|
Architecture (CPU-primary mode):
|
||||||
- CPU pool: Primary storage for all KV cache (num_cpu_blocks)
|
- CPU pool: Primary storage for all KV cache (num_cpu_blocks)
|
||||||
- GPU ring buffer: For decode H2D pipeline (num_kv_buffers)
|
- GPU buffer: Ring buffer for computation only (num_gpu_slots)
|
||||||
- Decode buffer: Per-layer accumulation of decode tokens (block_size)
|
- Logical blocks: What sequences reference (num_cpu_blocks)
|
||||||
|
|
||||||
Design:
|
Design:
|
||||||
- All KV cache is stored on CPU as primary storage
|
- All KV cache is stored on CPU as primary storage
|
||||||
- GPU ring buffer enables pipelined H2D transfers during decode
|
- GPU is used as a ring buffer for computation only (no persistent data)
|
||||||
- During prefill: KV is computed and offloaded layer-by-layer to CPU
|
- During prefill: KV is written to GPU ring slot, then offloaded to CPU
|
||||||
- During decode: Previous KV is loaded from CPU via ring buffer pipeline
|
- During decode: Previous KV is loaded from CPU to GPU for attention
|
||||||
|
- Ring buffer enables pipelined H2D transfers overlapped with computation
|
||||||
|
|
||||||
Note:
|
Note:
|
||||||
- Logical blocks map 1:1 with CPU blocks (total_blocks = num_cpu_blocks)
|
- Logical blocks map 1:1 with CPU blocks (total_blocks = num_cpu_blocks)
|
||||||
- GPU ring buffer is for decode pipeline, not persistent storage
|
- GPU slots are transient compute buffers, not tracked in logical blocks
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -90,31 +91,25 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
block_size: int,
|
block_size: int,
|
||||||
policy: Optional[EvictionPolicy] = None,
|
policy: Optional[EvictionPolicy] = None,
|
||||||
sparse_policy: "SparsePolicy" = None,
|
sparse_policy: "SparsePolicy" = None,
|
||||||
num_kv_buffers: int = 4,
|
|
||||||
max_seq_len: int = 131072,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize hybrid manager with layer-wise offload design.
|
Initialize hybrid manager with CPU-primary ring buffer design.
|
||||||
|
|
||||||
All KV cache is stored on CPU as primary storage. GPU ring buffer is used
|
All KV cache is stored on CPU as primary storage. GPU slots are used
|
||||||
for decode H2D pipeline.
|
as a ring buffer for computation only.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_gpu_slots: Number of GPU buffer slots (kept for backward compat, not used)
|
num_gpu_slots: Number of GPU buffer slots (ring buffer for computation)
|
||||||
num_cpu_blocks: Number of CPU pool blocks (primary storage)
|
num_cpu_blocks: Number of CPU pool blocks (primary storage)
|
||||||
block_size: Tokens per block
|
block_size: Tokens per block
|
||||||
policy: Eviction policy (default: LRU, used for prefix cache management)
|
policy: Eviction policy (default: LRU, used for prefix cache management)
|
||||||
sparse_policy: Sparse attention policy (Quest for decode-only sparse)
|
sparse_policy: Sparse attention policy (Quest for decode-only sparse)
|
||||||
num_kv_buffers: Ring buffer size for decode H2D pipeline
|
|
||||||
max_seq_len: Maximum sequence length for GPU buffer allocation
|
|
||||||
"""
|
"""
|
||||||
self._block_size = block_size
|
self._block_size = block_size
|
||||||
self.num_gpu_slots = num_gpu_slots
|
self.num_gpu_slots = num_gpu_slots
|
||||||
self.num_cpu_blocks = num_cpu_blocks
|
self.num_cpu_blocks = num_cpu_blocks
|
||||||
self.num_kv_buffers = num_kv_buffers
|
|
||||||
self.max_seq_len = max_seq_len
|
|
||||||
# In CPU-primary mode, logical blocks map 1:1 with CPU blocks
|
# In CPU-primary mode, logical blocks map 1:1 with CPU blocks
|
||||||
# GPU ring buffer is for decode pipeline, not persistent storage
|
# GPU slots are transient compute buffers, not tracked as logical blocks
|
||||||
self.total_blocks = num_cpu_blocks
|
self.total_blocks = num_cpu_blocks
|
||||||
|
|
||||||
# Eviction policy
|
# Eviction policy
|
||||||
@@ -152,7 +147,7 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
# Track blocks pending GPU load (for decode graph)
|
# Track blocks pending GPU load (for decode graph)
|
||||||
self.pending_gpu_loads: Set[int] = set() # logical_ids
|
self.pending_gpu_loads: Set[int] = set() # logical_ids
|
||||||
|
|
||||||
# Track blocks that have been prefilled (KV offloaded to CPU)
|
# Track blocks that have been prefilled (KV written) for chunked prefill
|
||||||
self.prefilled_blocks: Set[int] = set() # logical_ids
|
self.prefilled_blocks: Set[int] = set() # logical_ids
|
||||||
|
|
||||||
# Track decode starting position within block (for batched offload optimization)
|
# Track decode starting position within block (for batched offload optimization)
|
||||||
@@ -187,21 +182,13 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
num_kv_heads=num_kv_heads,
|
num_kv_heads=num_kv_heads,
|
||||||
head_dim=head_dim,
|
head_dim=head_dim,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
num_kv_buffers=self.num_kv_buffers,
|
|
||||||
max_seq_len=self.max_seq_len,
|
|
||||||
sparse_policy=self.sparse_policy,
|
sparse_policy=self.sparse_policy,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
def get_layer_cache(self, layer_id: int) -> Tuple[Tensor, Tensor]:
|
||||||
"""
|
"""Get GPU K/V cache tensors for a layer."""
|
||||||
Get GPU K/V cache tensors for a layer.
|
|
||||||
|
|
||||||
Note: In layer-wise offload mode, this returns empty tensors as KV
|
|
||||||
is managed directly by the offload engine's ring buffer.
|
|
||||||
"""
|
|
||||||
assert self.offload_engine is not None
|
assert self.offload_engine is not None
|
||||||
# Return empty tensors - actual KV is in offload_engine's ring buffer
|
return self.offload_engine.get_layer_cache(layer_id)
|
||||||
return torch.empty(0), torch.empty(0)
|
|
||||||
|
|
||||||
def can_allocate(self, seq: Sequence) -> bool:
|
def can_allocate(self, seq: Sequence) -> bool:
|
||||||
"""Check if we can allocate blocks for a new sequence."""
|
"""Check if we can allocate blocks for a new sequence."""
|
||||||
@@ -244,13 +231,6 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
seq.num_cached_tokens = 0
|
seq.num_cached_tokens = 0
|
||||||
seq.block_table.clear()
|
seq.block_table.clear()
|
||||||
|
|
||||||
# Clear decode tracking to prevent state pollution between requests
|
|
||||||
self.clear_decode_tracking(seq)
|
|
||||||
|
|
||||||
# Clear offload engine state (decode buffer, events)
|
|
||||||
if self.offload_engine is not None:
|
|
||||||
self.offload_engine.on_sequence_finished()
|
|
||||||
|
|
||||||
def can_append(self, seq: Sequence) -> bool:
|
def can_append(self, seq: Sequence) -> bool:
|
||||||
"""Check if we can append a token."""
|
"""Check if we can append a token."""
|
||||||
need_new_block = (len(seq) % self._block_size == 1)
|
need_new_block = (len(seq) % self._block_size == 1)
|
||||||
@@ -299,8 +279,8 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
"""
|
"""
|
||||||
Prepare KV cache for attention computation.
|
Prepare KV cache for attention computation.
|
||||||
|
|
||||||
In layer-wise offload mode, this is a no-op because KV transfers
|
In ring buffer mode, this is a no-op because chunked offload
|
||||||
are handled directly in model_runner's layer-by-layer methods.
|
paths handle H2D transfers directly in the attention layer.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -311,12 +291,12 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
"""
|
"""
|
||||||
Get GPU slot tables for sequences.
|
Get GPU slot tables for sequences.
|
||||||
|
|
||||||
In layer-wise offload mode, all blocks are on CPU, so this raises an error
|
In ring buffer mode, all blocks are on CPU, so this raises an error
|
||||||
if called. Use run_layerwise_offload_* methods instead.
|
if called. Use run_chunked_offload_* methods instead.
|
||||||
"""
|
"""
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"get_gpu_block_tables should not be called in layer-wise offload mode. "
|
"get_gpu_block_tables should not be called in ring buffer mode. "
|
||||||
"Use run_layerwise_offload_prefill/decode instead."
|
"Use run_chunked_offload_prefill/decode instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
def post_attention_cleanup(
|
def post_attention_cleanup(
|
||||||
@@ -327,18 +307,18 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
"""
|
"""
|
||||||
Cleanup after attention.
|
Cleanup after attention.
|
||||||
|
|
||||||
In layer-wise offload mode, this is a no-op because offload is handled
|
In ring buffer mode, this is a no-op because offload is handled
|
||||||
directly in model_runner's layer-by-layer methods.
|
directly in the chunked prefill/decode paths.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# ========== Layer-wise Offload Support ==========
|
# ========== Ring Buffer CPU-primary Chunked Prefill Support ==========
|
||||||
|
|
||||||
def get_prefilled_cpu_blocks(self, seq: Sequence) -> List[int]:
|
def get_prefilled_cpu_blocks(self, seq: Sequence) -> List[int]:
|
||||||
"""
|
"""
|
||||||
Get list of CPU block IDs for blocks that have been prefilled.
|
Get list of CPU block IDs for blocks that have been prefilled.
|
||||||
|
|
||||||
Used for loading prefilled KV during decode.
|
Used for loading previous KV during chunked prefill.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List of CPU block IDs in sequence order
|
List of CPU block IDs in sequence order
|
||||||
@@ -349,19 +329,17 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
block = self.logical_blocks[logical_id]
|
block = self.logical_blocks[logical_id]
|
||||||
if block.location == BlockLocation.CPU:
|
if block.location == BlockLocation.CPU:
|
||||||
cpu_blocks.append(block.cpu_block_id)
|
cpu_blocks.append(block.cpu_block_id)
|
||||||
# DEBUG: Log on first decode call
|
# logger.debug(
|
||||||
logger.debug(
|
# f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
|
||||||
f"[DEBUG] get_prefilled_cpu_blocks: block_table={list(seq.block_table)}, "
|
# f"returned cpu_blocks={cpu_blocks}"
|
||||||
f"prefilled_blocks={list(self.prefilled_blocks)}, "
|
# )
|
||||||
f"returned cpu_blocks={cpu_blocks}"
|
|
||||||
)
|
|
||||||
return cpu_blocks
|
return cpu_blocks
|
||||||
|
|
||||||
# ========== CPU Block Allocation ==========
|
# ========== Ring Buffer CPU-primary support ==========
|
||||||
|
|
||||||
def allocate_cpu_only(self, seq: Sequence) -> None:
|
def allocate_cpu_only(self, seq: Sequence) -> None:
|
||||||
"""
|
"""
|
||||||
Allocate CPU blocks for sequence (for layer-wise offload mode).
|
Allocate CPU blocks for sequence (for ring buffer mode).
|
||||||
|
|
||||||
Unlike allocate(), here all blocks are allocated to CPU,
|
Unlike allocate(), here all blocks are allocated to CPU,
|
||||||
GPU is only used as ring buffer for computation.
|
GPU is only used as ring buffer for computation.
|
||||||
@@ -392,10 +370,6 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
self.cpu_block_to_logical[cpu_block_id] = logical_id
|
self.cpu_block_to_logical[cpu_block_id] = logical_id
|
||||||
seq.block_table.append(logical_id)
|
seq.block_table.append(logical_id)
|
||||||
|
|
||||||
# DEBUG: Log allocated CPU blocks
|
|
||||||
cpu_blocks = [self.logical_blocks[lid].cpu_block_id for lid in seq.block_table]
|
|
||||||
logger.debug(f"[DEBUG] allocate_cpu_only: allocated cpu_blocks={cpu_blocks}")
|
|
||||||
|
|
||||||
# NOTE: Prefix cache disabled in offload mode
|
# NOTE: Prefix cache disabled in offload mode
|
||||||
# If enabled, would compute hash and update:
|
# If enabled, would compute hash and update:
|
||||||
# h = self.compute_hash(seq.block(i), prefix_hash)
|
# h = self.compute_hash(seq.block(i), prefix_hash)
|
||||||
@@ -443,8 +417,6 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
if block.location == BlockLocation.CPU:
|
if block.location == BlockLocation.CPU:
|
||||||
cpu_block_ids.append(block.cpu_block_id)
|
cpu_block_ids.append(block.cpu_block_id)
|
||||||
logical_ids.append(logical_id)
|
logical_ids.append(logical_id)
|
||||||
# DEBUG: Log during prefill
|
|
||||||
logger.debug(f"[DEBUG] get_all_cpu_blocks: returned cpu_block_ids={cpu_block_ids}")
|
|
||||||
return cpu_block_ids, logical_ids
|
return cpu_block_ids, logical_ids
|
||||||
|
|
||||||
def allocate_next_cpu_block(self, seq: Sequence) -> int:
|
def allocate_next_cpu_block(self, seq: Sequence) -> int:
|
||||||
@@ -496,6 +468,20 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
return block.cpu_block_id
|
return block.cpu_block_id
|
||||||
return -1
|
return -1
|
||||||
|
|
||||||
|
def get_write_slot_for_chunked_offload(self, seq: Sequence) -> int:
|
||||||
|
"""
|
||||||
|
Get GPU slot for writing new KV during chunked offload decode.
|
||||||
|
|
||||||
|
In ring buffer design, always use decode_slot (slot[0]) to write new KV.
|
||||||
|
This avoids conflicts with loading operations which use slots[1:].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
seq: Sequence
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
GPU slot ID (always decode_slot = 0)
|
||||||
|
"""
|
||||||
|
return self.offload_engine.decode_slot
|
||||||
|
|
||||||
def get_decode_start_pos(self, seq: Sequence) -> int:
|
def get_decode_start_pos(self, seq: Sequence) -> int:
|
||||||
"""
|
"""
|
||||||
@@ -517,12 +503,6 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
# Decode starts at the next position
|
# Decode starts at the next position
|
||||||
prefill_len = len(seq) - 1 # Current len includes the new decode token
|
prefill_len = len(seq) - 1 # Current len includes the new decode token
|
||||||
self._decode_start_pos[seq_id] = prefill_len % self._block_size
|
self._decode_start_pos[seq_id] = prefill_len % self._block_size
|
||||||
# DEBUG: Log first access
|
|
||||||
logger.debug(
|
|
||||||
f"[DEBUG] get_decode_start_pos FIRST ACCESS: seq_id={seq_id}, "
|
|
||||||
f"len(seq)={len(seq)}, prefill_len={prefill_len}, "
|
|
||||||
f"stored decode_start_pos={self._decode_start_pos[seq_id]}"
|
|
||||||
)
|
|
||||||
return self._decode_start_pos[seq_id]
|
return self._decode_start_pos[seq_id]
|
||||||
|
|
||||||
def reset_decode_start_pos(self, seq: Sequence) -> None:
|
def reset_decode_start_pos(self, seq: Sequence) -> None:
|
||||||
@@ -555,11 +535,6 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
# First decode step - store the prefill length
|
# First decode step - store the prefill length
|
||||||
# len(seq) - 1 because current len includes the first decode token
|
# len(seq) - 1 because current len includes the first decode token
|
||||||
self._prefill_len[seq_id] = len(seq) - 1
|
self._prefill_len[seq_id] = len(seq) - 1
|
||||||
# DEBUG: Log first access
|
|
||||||
logger.debug(
|
|
||||||
f"[DEBUG] get_prefill_len FIRST ACCESS: seq_id={seq_id}, "
|
|
||||||
f"len(seq)={len(seq)}, stored prefill_len={self._prefill_len[seq_id]}"
|
|
||||||
)
|
|
||||||
return self._prefill_len[seq_id]
|
return self._prefill_len[seq_id]
|
||||||
|
|
||||||
def clear_decode_tracking(self, seq: Sequence) -> None:
|
def clear_decode_tracking(self, seq: Sequence) -> None:
|
||||||
@@ -572,15 +547,6 @@ class HybridKVCacheManager(KVCacheManager):
|
|||||||
seq: Sequence
|
seq: Sequence
|
||||||
"""
|
"""
|
||||||
seq_id = id(seq)
|
seq_id = id(seq)
|
||||||
# DEBUG: Log clearing and CPU blocks
|
|
||||||
cpu_blocks = [self.logical_blocks[lid].cpu_block_id for lid in seq.block_table
|
|
||||||
if self.logical_blocks[lid].location == BlockLocation.CPU]
|
|
||||||
logger.debug(
|
|
||||||
f"[DEBUG] clear_decode_tracking: seq_id={seq_id}, "
|
|
||||||
f"clearing decode_start_pos={self._decode_start_pos.get(seq_id, 'N/A')}, "
|
|
||||||
f"prefill_len={self._prefill_len.get(seq_id, 'N/A')}, "
|
|
||||||
f"cpu_blocks={cpu_blocks}"
|
|
||||||
)
|
|
||||||
self._decode_start_pos.pop(seq_id, None)
|
self._decode_start_pos.pop(seq_id, None)
|
||||||
self._prefill_len.pop(seq_id, None)
|
self._prefill_len.pop(seq_id, None)
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -23,8 +23,6 @@ from nanovllm.config import SparsePolicyType
|
|||||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||||
from nanovllm.kvcache.sparse.minference import MInferencePolicy
|
|
||||||
from nanovllm.kvcache.sparse.xattn import XAttentionPolicy
|
|
||||||
|
|
||||||
|
|
||||||
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
||||||
@@ -57,26 +55,6 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
|
|||||||
)
|
)
|
||||||
return QuestPolicy(config)
|
return QuestPolicy(config)
|
||||||
|
|
||||||
elif policy_type == SparsePolicyType.MINFERENCE:
|
|
||||||
return MInferencePolicy(
|
|
||||||
vertical_size=kwargs.get("vertical_size", 1000),
|
|
||||||
slash_size=kwargs.get("slash_size", 6096),
|
|
||||||
adaptive_budget=kwargs.get("adaptive_budget", 0.3),
|
|
||||||
num_sink_tokens=kwargs.get("num_sink_tokens", 30),
|
|
||||||
num_recent_diags=kwargs.get("num_recent_diags", 100),
|
|
||||||
)
|
|
||||||
|
|
||||||
elif policy_type == SparsePolicyType.XATTN:
|
|
||||||
return XAttentionPolicy(
|
|
||||||
stride=kwargs.get("stride", 8),
|
|
||||||
threshold=kwargs.get("threshold", 0.9),
|
|
||||||
chunk_size=kwargs.get("chunk_size", 16384),
|
|
||||||
use_triton=kwargs.get("use_triton", True),
|
|
||||||
keep_sink=kwargs.get("keep_sink", False),
|
|
||||||
keep_recent=kwargs.get("keep_recent", False),
|
|
||||||
norm=kwargs.get("norm", 1.0),
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown policy type: {policy_type}")
|
raise ValueError(f"Unknown policy type: {policy_type}")
|
||||||
|
|
||||||
@@ -89,7 +67,5 @@ __all__ = [
|
|||||||
"QuestPolicy",
|
"QuestPolicy",
|
||||||
"QuestConfig",
|
"QuestConfig",
|
||||||
"BlockMetadataManager",
|
"BlockMetadataManager",
|
||||||
"MInferencePolicy",
|
|
||||||
"XAttentionPolicy",
|
|
||||||
"create_sparse_policy",
|
"create_sparse_policy",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
# Full attention supports both prefill and decode
|
# Full attention supports both prefill and decode
|
||||||
supports_prefill = True
|
supports_prefill = True
|
||||||
supports_decode = True
|
supports_decode = True
|
||||||
requires_block_selection = False # Load all blocks, no selective loading
|
|
||||||
|
|
||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -1,320 +0,0 @@
|
|||||||
"""
|
|
||||||
Triton kernels for XAttention sparse attention.
|
|
||||||
|
|
||||||
Copied and adapted from COMPASS/compass/src/kernels.py
|
|
||||||
for XAttention integration in nano-vllm.
|
|
||||||
|
|
||||||
Requirements:
|
|
||||||
- Triton >= 2.1.0
|
|
||||||
- CUDA compute capability SM 80+ (RTX 3090, A100, H100, etc.)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import math
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def softmax_fuse_block_sum_kernel_causal(
|
|
||||||
In,
|
|
||||||
Out,
|
|
||||||
scale,
|
|
||||||
input_stride_0,
|
|
||||||
input_stride_1,
|
|
||||||
input_stride_2,
|
|
||||||
output_stride_0,
|
|
||||||
output_stride_1,
|
|
||||||
output_stride_2,
|
|
||||||
real_q_len,
|
|
||||||
k_len,
|
|
||||||
chunk_start,
|
|
||||||
chunk_end,
|
|
||||||
segment_size: tl.constexpr,
|
|
||||||
block_size: tl.constexpr,
|
|
||||||
):
|
|
||||||
block_id = tl.program_id(0)
|
|
||||||
head_id = tl.program_id(1)
|
|
||||||
batch_id = tl.program_id(2)
|
|
||||||
|
|
||||||
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
|
|
||||||
offs_k = tl.arange(0, segment_size)
|
|
||||||
|
|
||||||
num_iters = k_len // segment_size
|
|
||||||
num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1) // segment_size
|
|
||||||
|
|
||||||
m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
|
|
||||||
l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0
|
|
||||||
|
|
||||||
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
|
|
||||||
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
|
|
||||||
|
|
||||||
output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
|
|
||||||
output_ptr = output_ptr + tl.arange(0, segment_size // block_size)
|
|
||||||
|
|
||||||
for iter in range(0, num_iters_before_causal):
|
|
||||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
|
||||||
m_local = tl.max(X, 1)
|
|
||||||
m_new = tl.maximum(m_i, m_local)
|
|
||||||
alpha = tl.math.exp2(m_i - m_new)
|
|
||||||
|
|
||||||
X = X - m_new[:, None]
|
|
||||||
l_local = tl.sum(tl.math.exp2(X), 1)
|
|
||||||
l_i = l_i * alpha + l_local
|
|
||||||
|
|
||||||
m_i = m_new
|
|
||||||
|
|
||||||
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
|
|
||||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
|
||||||
mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size)
|
|
||||||
X = tl.where(mask, X, -1.0e6)
|
|
||||||
m_local = tl.max(X, 1)
|
|
||||||
m_new = tl.maximum(m_i, m_local)
|
|
||||||
alpha = tl.math.exp2(m_i - m_new)
|
|
||||||
|
|
||||||
X = X - m_new[:, None]
|
|
||||||
l_local = tl.sum(tl.math.exp2(X), 1)
|
|
||||||
l_i = l_i * alpha + l_local
|
|
||||||
|
|
||||||
m_i = m_new
|
|
||||||
|
|
||||||
l_i_inv = 1.0 / l_i
|
|
||||||
|
|
||||||
sum_mask = offs_q[:, None] < real_q_len
|
|
||||||
|
|
||||||
for iter in range(0, num_iters_before_causal):
|
|
||||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
|
||||||
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
|
||||||
X = tl.where(sum_mask, X, 0)
|
|
||||||
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
|
||||||
X = tl.sum(X, 2)
|
|
||||||
X = tl.sum(X, 0)
|
|
||||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
|
||||||
|
|
||||||
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
|
|
||||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
|
||||||
mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size)
|
|
||||||
X = tl.where(mask, X, -1.0e6)
|
|
||||||
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
|
||||||
X = tl.where(sum_mask, X, 0)
|
|
||||||
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
|
||||||
X = tl.sum(X, 2)
|
|
||||||
X = tl.sum(X, 0)
|
|
||||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
|
||||||
|
|
||||||
for iter in range(num_iters_before_causal + 1, num_iters):
|
|
||||||
X = tl.zeros([segment_size // block_size], dtype=tl.float32)
|
|
||||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def softmax_fuse_block_sum_kernel_non_causal(
|
|
||||||
In,
|
|
||||||
Out,
|
|
||||||
scale,
|
|
||||||
input_stride_0,
|
|
||||||
input_stride_1,
|
|
||||||
input_stride_2,
|
|
||||||
output_stride_0,
|
|
||||||
output_stride_1,
|
|
||||||
output_stride_2,
|
|
||||||
real_q_len,
|
|
||||||
k_len,
|
|
||||||
chunk_start,
|
|
||||||
chunk_end,
|
|
||||||
segment_size: tl.constexpr,
|
|
||||||
block_size: tl.constexpr,
|
|
||||||
):
|
|
||||||
block_id = tl.program_id(0)
|
|
||||||
head_id = tl.program_id(1)
|
|
||||||
batch_id = tl.program_id(2)
|
|
||||||
|
|
||||||
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
|
|
||||||
offs_k = tl.arange(0, segment_size)
|
|
||||||
|
|
||||||
num_iters = k_len // segment_size
|
|
||||||
|
|
||||||
m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
|
|
||||||
l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0
|
|
||||||
|
|
||||||
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
|
|
||||||
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
|
|
||||||
|
|
||||||
output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
|
|
||||||
output_ptr = output_ptr + tl.arange(0, segment_size // block_size)
|
|
||||||
|
|
||||||
for iter in range(0, num_iters):
|
|
||||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
|
||||||
m_local = tl.max(X, 1)
|
|
||||||
m_new = tl.maximum(m_i, m_local)
|
|
||||||
alpha = tl.math.exp2(m_i - m_new)
|
|
||||||
|
|
||||||
X = X - m_new[:, None]
|
|
||||||
l_local = tl.sum(tl.math.exp2(X), 1)
|
|
||||||
l_i = l_i * alpha + l_local
|
|
||||||
|
|
||||||
m_i = m_new
|
|
||||||
|
|
||||||
l_i_inv = 1.0 / l_i
|
|
||||||
|
|
||||||
sum_mask = offs_q[:, None] < real_q_len
|
|
||||||
|
|
||||||
for iter in range(0, num_iters):
|
|
||||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
|
||||||
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
|
||||||
X = tl.where(sum_mask, X, 0)
|
|
||||||
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
|
||||||
X = tl.sum(X, 2)
|
|
||||||
X = tl.sum(X, 0)
|
|
||||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def flat_group_gemm_fuse_reshape_kernel(Q, K, Out,
|
|
||||||
stride_qz, stride_qh, stride_qn,
|
|
||||||
stride_kz, stride_kh, stride_kn,
|
|
||||||
stride_oz, stride_oh, stride_on,
|
|
||||||
chunk_start, chunk_end,
|
|
||||||
H: tl.constexpr,
|
|
||||||
STRIDE: tl.constexpr,
|
|
||||||
HEAD_DIM: tl.constexpr,
|
|
||||||
BLOCK_M: tl.constexpr,
|
|
||||||
BLOCK_N: tl.constexpr,
|
|
||||||
is_causal: tl.constexpr,
|
|
||||||
):
|
|
||||||
block_m = tl.program_id(0).to(tl.int64)
|
|
||||||
block_n = tl.program_id(1).to(tl.int64)
|
|
||||||
batch_id = tl.program_id(2).to(tl.int64) // H
|
|
||||||
head_id = tl.program_id(2).to(tl.int64) % H
|
|
||||||
|
|
||||||
if is_causal:
|
|
||||||
if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N:
|
|
||||||
return
|
|
||||||
|
|
||||||
Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * STRIDE * stride_qn
|
|
||||||
K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * STRIDE * stride_kn
|
|
||||||
|
|
||||||
Q_ptrs = Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE) + tl.arange(0, HEAD_DIM)[None, :] + stride_qn * (STRIDE - 1)
|
|
||||||
K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE) + tl.arange(0, HEAD_DIM)[:, None]
|
|
||||||
|
|
||||||
o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
|
||||||
|
|
||||||
for iter in range(STRIDE):
|
|
||||||
q = tl.load(Q_ptrs - iter * stride_qn)
|
|
||||||
k = tl.load(K_ptrs + iter * stride_kn)
|
|
||||||
o += tl.dot(q, k)
|
|
||||||
|
|
||||||
O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N
|
|
||||||
O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :]
|
|
||||||
|
|
||||||
tl.store(O_ptrs, o.to(Out.type.element_ty))
|
|
||||||
|
|
||||||
|
|
||||||
def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size, segment_size, chunk_start, chunk_end, real_q_len, scale, is_causal=True):
|
|
||||||
"""Wrapper for Triton softmax-fuse-block-sum kernel."""
|
|
||||||
batch_size, num_heads, q_len, k_len = attn_weights_slice.shape
|
|
||||||
assert q_len % reshaped_block_size == 0
|
|
||||||
assert k_len % segment_size == 0
|
|
||||||
assert segment_size % reshaped_block_size == 0
|
|
||||||
assert attn_weights_slice.stride(-1) == 1
|
|
||||||
|
|
||||||
output = torch.empty(
|
|
||||||
(batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size),
|
|
||||||
dtype=attn_weights_slice.dtype,
|
|
||||||
device=attn_weights_slice.device
|
|
||||||
)
|
|
||||||
|
|
||||||
grid = (q_len // reshaped_block_size, num_heads, batch_size)
|
|
||||||
|
|
||||||
if is_causal:
|
|
||||||
softmax_fuse_block_sum_kernel_causal[grid](
|
|
||||||
attn_weights_slice,
|
|
||||||
output,
|
|
||||||
scale,
|
|
||||||
attn_weights_slice.stride(0),
|
|
||||||
attn_weights_slice.stride(1),
|
|
||||||
attn_weights_slice.stride(2),
|
|
||||||
output.stride(0),
|
|
||||||
output.stride(1),
|
|
||||||
output.stride(2),
|
|
||||||
real_q_len,
|
|
||||||
k_len,
|
|
||||||
chunk_start,
|
|
||||||
chunk_end,
|
|
||||||
segment_size,
|
|
||||||
reshaped_block_size,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
softmax_fuse_block_sum_kernel_non_causal[grid](
|
|
||||||
attn_weights_slice,
|
|
||||||
output,
|
|
||||||
scale,
|
|
||||||
attn_weights_slice.stride(0),
|
|
||||||
attn_weights_slice.stride(1),
|
|
||||||
attn_weights_slice.stride(2),
|
|
||||||
output.stride(0),
|
|
||||||
output.stride(1),
|
|
||||||
output.stride(2),
|
|
||||||
real_q_len,
|
|
||||||
k_len,
|
|
||||||
chunk_start,
|
|
||||||
chunk_end,
|
|
||||||
segment_size,
|
|
||||||
reshaped_block_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def flat_group_gemm_fuse_reshape(query_states, key_states, stride, chunk_start, chunk_end, is_causal=True):
|
|
||||||
"""Wrapper for Triton flat-group-gemm-fuse-reshape kernel."""
|
|
||||||
batch_size, num_heads, q_len, head_dim = query_states.shape
|
|
||||||
kv_len = key_states.shape[2]
|
|
||||||
|
|
||||||
assert key_states.shape[0] == batch_size
|
|
||||||
assert key_states.shape[1] == num_heads
|
|
||||||
assert key_states.shape[3] == head_dim
|
|
||||||
|
|
||||||
output = torch.empty(
|
|
||||||
(batch_size, num_heads, q_len // stride, kv_len // stride),
|
|
||||||
dtype=query_states.dtype,
|
|
||||||
device=query_states.device
|
|
||||||
)
|
|
||||||
|
|
||||||
# Adjust block size based on GPU shared memory
|
|
||||||
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
|
||||||
if props.total_memory < 30 * 1024**3: # Less than 30GB (e.g., RTX 3090 24GB)
|
|
||||||
BLOCK_M = 64
|
|
||||||
BLOCK_N = 64
|
|
||||||
else:
|
|
||||||
BLOCK_M = 128
|
|
||||||
BLOCK_N = 128
|
|
||||||
|
|
||||||
assert q_len % (stride * BLOCK_M) == 0
|
|
||||||
assert kv_len % (stride * BLOCK_N) == 0
|
|
||||||
|
|
||||||
grid = (q_len // stride // BLOCK_M, kv_len // stride // BLOCK_N, batch_size * num_heads)
|
|
||||||
flat_group_gemm_fuse_reshape_kernel[grid](
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
output,
|
|
||||||
query_states.stride(0),
|
|
||||||
query_states.stride(1),
|
|
||||||
query_states.stride(2),
|
|
||||||
key_states.stride(0),
|
|
||||||
key_states.stride(1),
|
|
||||||
key_states.stride(2),
|
|
||||||
output.stride(0),
|
|
||||||
output.stride(1),
|
|
||||||
output.stride(2),
|
|
||||||
chunk_start,
|
|
||||||
chunk_end,
|
|
||||||
num_heads,
|
|
||||||
stride,
|
|
||||||
head_dim,
|
|
||||||
BLOCK_M,
|
|
||||||
BLOCK_N,
|
|
||||||
is_causal,
|
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
|
||||||
@@ -1,354 +0,0 @@
|
|||||||
"""
|
|
||||||
MInference sparse attention policy.
|
|
||||||
|
|
||||||
Implements vertical + slash sparse pattern estimation using the last 64 query tokens.
|
|
||||||
Reference: MInference paper (https://arxiv.org/abs/2407.02490)
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import List, Tuple, Optional
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
|
||||||
|
|
||||||
|
|
||||||
class MInferencePolicy(SparsePolicy):
|
|
||||||
"""
|
|
||||||
MInference sparse prefill policy using vertical + slash pattern.
|
|
||||||
|
|
||||||
This policy estimates sparse attention patterns by analyzing attention
|
|
||||||
scores from the last 64 query tokens, then selects:
|
|
||||||
- Vertical: Key positions that are important across all queries
|
|
||||||
- Slash: Diagonal bands (local context)
|
|
||||||
|
|
||||||
The estimated pattern is then used to compute sparse attention.
|
|
||||||
|
|
||||||
Note: This policy is designed for GPU-only prefill. For CPU offload,
|
|
||||||
the pattern estimation and sparse attention will be handled differently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
supports_prefill = True
|
|
||||||
supports_decode = False # MInference is prefill-only sparse strategy
|
|
||||||
requires_block_selection = False # MInference only affects attention computation, not KV load
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
vertical_size: int = 1000,
|
|
||||||
slash_size: int = 6096,
|
|
||||||
adaptive_budget: Optional[float] = 0.3,
|
|
||||||
num_sink_tokens: int = 30,
|
|
||||||
num_recent_diags: int = 100,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize MInference policy.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
vertical_size: Number of vertical (column) positions to keep
|
|
||||||
slash_size: Number of diagonal bands to keep
|
|
||||||
adaptive_budget: If set, compute budget as fraction of seq_len
|
|
||||||
(overrides vertical_size and slash_size)
|
|
||||||
num_sink_tokens: Number of initial sink tokens to always keep
|
|
||||||
num_recent_diags: Number of recent diagonals to always keep
|
|
||||||
"""
|
|
||||||
self.vertical_size = vertical_size
|
|
||||||
self.slash_size = slash_size
|
|
||||||
self.adaptive_budget = adaptive_budget
|
|
||||||
self.num_sink_tokens = num_sink_tokens
|
|
||||||
self.num_recent_diags = num_recent_diags
|
|
||||||
|
|
||||||
# Cache for last-q causal mask
|
|
||||||
self._last_q_mask_cache: dict = {}
|
|
||||||
|
|
||||||
def _get_causal_mask(self, last_q: int, seq_len: int, device: torch.device) -> torch.Tensor:
|
|
||||||
"""Get causal mask for last-q attention."""
|
|
||||||
cache_key = (last_q, seq_len, device)
|
|
||||||
if cache_key not in self._last_q_mask_cache:
|
|
||||||
# Create mask where last_q queries can attend to all previous positions
|
|
||||||
# Shape: [last_q, seq_len]
|
|
||||||
mask = torch.ones(last_q, seq_len, device=device, dtype=torch.bool)
|
|
||||||
# Apply causal constraint for the last last_q positions
|
|
||||||
# Query i (from last_q) can only attend to positions <= (seq_len - last_q + i)
|
|
||||||
for i in range(last_q):
|
|
||||||
mask[i, seq_len - last_q + i + 1:] = False
|
|
||||||
self._last_q_mask_cache[cache_key] = mask
|
|
||||||
return self._last_q_mask_cache[cache_key]
|
|
||||||
|
|
||||||
def estimate_pattern(
|
|
||||||
self,
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
layer_id: int,
|
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
Estimate vertical + slash sparse pattern using last 64 query tokens.
|
|
||||||
Memory-optimized for long sequences (64K+).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
q: Query tensor [seq_len, num_heads, head_dim]
|
|
||||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
layer_id: Current layer index (for potential layer-specific patterns)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (vertical_indices, slash_indices):
|
|
||||||
- vertical_indices: [num_heads, vertical_size] - important K positions
|
|
||||||
- slash_indices: [num_heads, slash_size] - diagonal offsets
|
|
||||||
"""
|
|
||||||
seq_len = q.shape[0]
|
|
||||||
num_heads = q.shape[1]
|
|
||||||
head_dim = q.shape[2]
|
|
||||||
num_kv_heads = k.shape[1]
|
|
||||||
|
|
||||||
# Adaptive budget
|
|
||||||
if self.adaptive_budget is not None:
|
|
||||||
budget = int(seq_len * self.adaptive_budget)
|
|
||||||
vertical_size = max(self.num_sink_tokens + 1, int(budget * 0.2))
|
|
||||||
slash_size = max(self.num_recent_diags + 1, int(budget * 0.8))
|
|
||||||
else:
|
|
||||||
vertical_size = self.vertical_size
|
|
||||||
slash_size = self.slash_size
|
|
||||||
|
|
||||||
# Use last 64 Q tokens for estimation
|
|
||||||
last_q = min(64, seq_len)
|
|
||||||
q_last = q[-last_q:] # [last_q, heads, dim] - this is a view, not a copy
|
|
||||||
|
|
||||||
# Handle GQA: if num_kv_heads < num_heads, we need to expand K
|
|
||||||
if num_kv_heads < num_heads:
|
|
||||||
num_groups = num_heads // num_kv_heads
|
|
||||||
k_work = k.repeat_interleave(num_groups, dim=1)
|
|
||||||
else:
|
|
||||||
k_work = k
|
|
||||||
|
|
||||||
# Compute attention scores: [heads, last_q, seq_len]
|
|
||||||
scale = 1.0 / math.sqrt(head_dim)
|
|
||||||
qk = torch.einsum('qhd,khd->hqk', q_last, k_work) * scale
|
|
||||||
|
|
||||||
# Free k_work if it was a copy
|
|
||||||
if num_kv_heads < num_heads:
|
|
||||||
del k_work
|
|
||||||
|
|
||||||
# Apply causal mask for last positions (in-place)
|
|
||||||
causal_mask = self._get_causal_mask(last_q, seq_len, q.device)
|
|
||||||
qk.masked_fill_(~causal_mask.unsqueeze(0), float('-inf'))
|
|
||||||
|
|
||||||
# Softmax (in-place where possible)
|
|
||||||
qk = F.softmax(qk, dim=-1, dtype=torch.float32)
|
|
||||||
|
|
||||||
# === Vertical pattern ===
|
|
||||||
# Sum across query dimension -> importance of each K position
|
|
||||||
vertical_scores = qk.sum(dim=1) # [heads, seq_len]
|
|
||||||
|
|
||||||
# Force keep first num_sink_tokens (attention sinks) - in-place
|
|
||||||
vertical_scores[:, :self.num_sink_tokens] = float('inf')
|
|
||||||
|
|
||||||
# Select top-k
|
|
||||||
actual_vertical = min(vertical_size, seq_len)
|
|
||||||
vertical_indices = vertical_scores.topk(actual_vertical, dim=-1).indices
|
|
||||||
vertical_indices = vertical_indices.sort(dim=-1).values
|
|
||||||
del vertical_scores
|
|
||||||
|
|
||||||
# === Slash pattern ===
|
|
||||||
# Create diagonal index matrix: [last_q, seq_len] with int32 to save memory
|
|
||||||
q_indices = torch.arange(last_q, device=q.device, dtype=torch.int32).unsqueeze(1)
|
|
||||||
k_indices = torch.arange(seq_len, device=q.device, dtype=torch.int32).unsqueeze(0)
|
|
||||||
diag_indices = (seq_len - last_q + q_indices) - k_indices # [last_q, seq_len]
|
|
||||||
del q_indices
|
|
||||||
|
|
||||||
# Create causal mask for slash computation
|
|
||||||
q_pos = seq_len - last_q + torch.arange(last_q, device=q.device, dtype=torch.int32).unsqueeze(1)
|
|
||||||
slash_causal_mask = k_indices <= q_pos
|
|
||||||
del q_pos, k_indices
|
|
||||||
|
|
||||||
# Clamp diagonal indices to valid range
|
|
||||||
diag_indices = diag_indices.clamp(0, seq_len - 1)
|
|
||||||
|
|
||||||
# Apply causal mask to qk (in-place) for slash computation
|
|
||||||
qk[:, ~slash_causal_mask] = 0
|
|
||||||
del slash_causal_mask
|
|
||||||
|
|
||||||
# Accumulate scores per diagonal - process in batches to save memory
|
|
||||||
slash_scores = torch.zeros(num_heads, seq_len, device=q.device, dtype=torch.float32)
|
|
||||||
|
|
||||||
# Process heads in chunks to reduce peak memory for diag_indices_expanded
|
|
||||||
chunk_size = min(8, num_heads) # Process 8 heads at a time
|
|
||||||
for h_start in range(0, num_heads, chunk_size):
|
|
||||||
h_end = min(h_start + chunk_size, num_heads)
|
|
||||||
n_heads_chunk = h_end - h_start
|
|
||||||
|
|
||||||
# Expand diag_indices only for this chunk
|
|
||||||
diag_chunk = diag_indices.unsqueeze(0).expand(n_heads_chunk, -1, -1).long()
|
|
||||||
qk_chunk = qk[h_start:h_end]
|
|
||||||
|
|
||||||
slash_scores[h_start:h_end].scatter_add_(
|
|
||||||
1,
|
|
||||||
diag_chunk.reshape(n_heads_chunk, -1),
|
|
||||||
qk_chunk.reshape(n_heads_chunk, -1)
|
|
||||||
)
|
|
||||||
del diag_chunk, qk_chunk
|
|
||||||
|
|
||||||
del diag_indices, qk
|
|
||||||
|
|
||||||
# Force keep first num_recent_diags (in-place)
|
|
||||||
slash_scores[:, :self.num_recent_diags] = float('inf')
|
|
||||||
|
|
||||||
# Select top-k diagonal indices
|
|
||||||
actual_slash = min(slash_size, seq_len)
|
|
||||||
slash_indices = slash_scores.topk(actual_slash, dim=-1).indices
|
|
||||||
slash_indices = slash_indices.sort(dim=-1).values
|
|
||||||
del slash_scores
|
|
||||||
|
|
||||||
return vertical_indices, slash_indices
|
|
||||||
|
|
||||||
def select_blocks(
|
|
||||||
self,
|
|
||||||
available_blocks: List[int],
|
|
||||||
ctx: PolicyContext,
|
|
||||||
) -> List[int]:
|
|
||||||
"""
|
|
||||||
Select blocks for chunked CPU offload mode.
|
|
||||||
|
|
||||||
For MInference in GPU-only mode, this method is not used.
|
|
||||||
In CPU offload mode, it would select blocks based on the sparse pattern.
|
|
||||||
|
|
||||||
For now, return all blocks (full attention fallback).
|
|
||||||
"""
|
|
||||||
# MInference pattern is computed in attention.forward()
|
|
||||||
# For CPU offload integration (Phase B), this would use the pattern
|
|
||||||
return available_blocks
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Reset policy state."""
|
|
||||||
self._last_q_mask_cache.clear()
|
|
||||||
|
|
||||||
def sparse_prefill_attention(
|
|
||||||
self,
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
layer_id: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute MInference sparse attention for prefill.
|
|
||||||
|
|
||||||
Uses vertical + slash pattern to compute sparse attention efficiently.
|
|
||||||
Memory-optimized to handle long sequences (64K+) by freeing intermediate tensors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
q: Query tensor [seq_len, num_heads, head_dim]
|
|
||||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
layer_id: Current transformer layer index
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Attention output [seq_len, num_heads, head_dim]
|
|
||||||
"""
|
|
||||||
from minference.ops.pit_sparse_flash_attention_v2 import _triton_mixed_sparse_attention
|
|
||||||
from minference.cuda import convert_vertical_slash_indexes
|
|
||||||
|
|
||||||
seq_len = q.shape[0]
|
|
||||||
num_heads = q.shape[1]
|
|
||||||
head_dim = q.shape[2]
|
|
||||||
num_kv_heads = k.shape[1]
|
|
||||||
|
|
||||||
# Estimate sparse pattern (uses temporary memory for qk scores)
|
|
||||||
vertical_indices, slash_indices = self.estimate_pattern(q, k, layer_id)
|
|
||||||
# Free any cached memory from pattern estimation
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# Triton sparse attention kernel parameters
|
|
||||||
block_size_M = 64
|
|
||||||
block_size_N = 64
|
|
||||||
|
|
||||||
# Calculate padding
|
|
||||||
pad = (block_size_M - seq_len) & (block_size_M - 1)
|
|
||||||
need_head_pad = head_dim not in [16, 32, 64, 128, 256, 512]
|
|
||||||
head_pad = (2 ** math.ceil(math.log2(head_dim)) - head_dim) if need_head_pad else 0
|
|
||||||
|
|
||||||
# Handle GQA: expand K/V to match query heads
|
|
||||||
# Do this BEFORE creating batched tensors to avoid double copies
|
|
||||||
if num_kv_heads < num_heads:
|
|
||||||
num_groups = num_heads // num_kv_heads
|
|
||||||
# Use repeat_interleave for memory-efficient expansion
|
|
||||||
k_work = k.repeat_interleave(num_groups, dim=1)
|
|
||||||
v_work = v.repeat_interleave(num_groups, dim=1)
|
|
||||||
else:
|
|
||||||
k_work = k
|
|
||||||
v_work = v
|
|
||||||
|
|
||||||
# Transform Q to [batch, heads, seq, dim] format with padding in one step
|
|
||||||
# This avoids creating intermediate copies
|
|
||||||
if pad > 0 or head_pad > 0:
|
|
||||||
q_batched = torch.nn.functional.pad(
|
|
||||||
q.unsqueeze(0).transpose(1, 2),
|
|
||||||
[0, head_pad, 0, pad, 0, 0, 0, 0]
|
|
||||||
).contiguous()
|
|
||||||
else:
|
|
||||||
q_batched = q.unsqueeze(0).transpose(1, 2).contiguous()
|
|
||||||
|
|
||||||
# Transform K to batched format
|
|
||||||
if pad > 0 or head_pad > 0:
|
|
||||||
k_batched = torch.nn.functional.pad(
|
|
||||||
k_work.unsqueeze(0).transpose(1, 2),
|
|
||||||
[0, head_pad, 0, pad, 0, 0, 0, 0]
|
|
||||||
).contiguous()
|
|
||||||
else:
|
|
||||||
k_batched = k_work.unsqueeze(0).transpose(1, 2).contiguous()
|
|
||||||
|
|
||||||
# Free k_work if it was a copy (GQA case)
|
|
||||||
if num_kv_heads < num_heads:
|
|
||||||
del k_work
|
|
||||||
|
|
||||||
# Transform V to batched format
|
|
||||||
if pad > 0 or head_pad > 0:
|
|
||||||
v_batched = torch.nn.functional.pad(
|
|
||||||
v_work.unsqueeze(0).transpose(1, 2),
|
|
||||||
[0, head_pad, 0, pad, 0, 0, 0, 0]
|
|
||||||
).contiguous()
|
|
||||||
else:
|
|
||||||
v_batched = v_work.unsqueeze(0).transpose(1, 2).contiguous()
|
|
||||||
|
|
||||||
# Free v_work if it was a copy (GQA case)
|
|
||||||
if num_kv_heads < num_heads:
|
|
||||||
del v_work
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# Prepare indices for Triton kernel
|
|
||||||
v_idx = vertical_indices.to(torch.int32).reshape((1, num_heads, -1))
|
|
||||||
v_idx = v_idx.sort(dim=-1, descending=False)[0].contiguous()
|
|
||||||
del vertical_indices
|
|
||||||
|
|
||||||
s_idx = slash_indices.to(torch.int32).reshape((1, num_heads, -1))
|
|
||||||
s_idx = s_idx.sort(dim=-1, descending=True)[0].contiguous()
|
|
||||||
del slash_indices
|
|
||||||
|
|
||||||
seqlens = torch.tensor([seq_len], dtype=torch.int32, device=q.device)
|
|
||||||
sm_scale = head_dim ** -0.5
|
|
||||||
|
|
||||||
# Convert vertical+slash indices to block sparse format
|
|
||||||
block_count, block_offset, column_count, column_index = convert_vertical_slash_indexes(
|
|
||||||
seqlens, v_idx, s_idx, seq_len, block_size_M, block_size_N,
|
|
||||||
)
|
|
||||||
del v_idx, s_idx
|
|
||||||
|
|
||||||
# Call Triton mixed sparse attention kernel
|
|
||||||
o = _triton_mixed_sparse_attention(
|
|
||||||
q_batched, k_batched, v_batched, seqlens,
|
|
||||||
block_count, block_offset, column_count, column_index,
|
|
||||||
sm_scale, block_size_M, block_size_N,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Free input tensors immediately after kernel call
|
|
||||||
del q_batched, k_batched, v_batched
|
|
||||||
del block_count, block_offset, column_count, column_index
|
|
||||||
|
|
||||||
# Remove padding and convert back to [seq_len, num_heads, head_dim]
|
|
||||||
o = o[..., :seq_len, :head_dim]
|
|
||||||
o = o.transpose(1, 2).squeeze(0).contiguous()
|
|
||||||
|
|
||||||
return o
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return (f"MInferencePolicy("
|
|
||||||
f"adaptive_budget={self.adaptive_budget}, "
|
|
||||||
f"vertical_size={self.vertical_size}, "
|
|
||||||
f"slash_size={self.slash_size})")
|
|
||||||
@@ -77,12 +77,6 @@ class SparsePolicy(ABC):
|
|||||||
supports_prefill: bool = True
|
supports_prefill: bool = True
|
||||||
supports_decode: bool = True
|
supports_decode: bool = True
|
||||||
|
|
||||||
# Whether this policy requires selective block loading during decode
|
|
||||||
# If True: OffloadEngine will call select_blocks() before loading KV from CPU
|
|
||||||
# If False: OffloadEngine will load all blocks (select_blocks ignored for load)
|
|
||||||
# Example: MInference=False (only affects attention), Quest=True (affects load)
|
|
||||||
requires_block_selection: bool = False
|
|
||||||
|
|
||||||
def initialize(
|
def initialize(
|
||||||
self,
|
self,
|
||||||
num_layers: int,
|
num_layers: int,
|
||||||
@@ -189,32 +183,5 @@ class SparsePolicy(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def sparse_prefill_attention(
|
|
||||||
self,
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
layer_id: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute sparse attention for prefill phase.
|
|
||||||
|
|
||||||
This method is called when supports_prefill=True and the policy
|
|
||||||
is used for GPU-only sparse prefill (no CPU offload).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
q: Query tensor [seq_len, num_heads, head_dim]
|
|
||||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
layer_id: Current transformer layer index
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Attention output [seq_len, num_heads, head_dim]
|
|
||||||
"""
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"{self.__class__.__name__} does not implement sparse_prefill_attention. "
|
|
||||||
"Set supports_prefill=False or implement this method."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"{self.__class__.__name__}()"
|
return f"{self.__class__.__name__}()"
|
||||||
|
|||||||
@@ -158,7 +158,6 @@ class QuestPolicy(SparsePolicy):
|
|||||||
# Quest is decode-only
|
# Quest is decode-only
|
||||||
supports_prefill = False
|
supports_prefill = False
|
||||||
supports_decode = True
|
supports_decode = True
|
||||||
requires_block_selection = True # Quest affects KV load strategy (selective block loading)
|
|
||||||
|
|
||||||
def __init__(self, config: QuestConfig):
|
def __init__(self, config: QuestConfig):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,156 +0,0 @@
|
|||||||
"""
|
|
||||||
Utility functions for sparse attention policies.
|
|
||||||
|
|
||||||
Copied from COMPASS/compass/src/utils.py for XAttention integration.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
def find_blocks_chunked(
|
|
||||||
input_tensor, current_index, threshold, num_to_choose, decoding: bool, mode: str = "both", causal=True
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Finds and selects relevant blocks of attention for transformer-based models based on a
|
|
||||||
threshold or a predefined number of blocks.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
- input_tensor (torch.Tensor): The input tensor of shape (batch_size, head_num, chunk_num, block_num).
|
|
||||||
- current_index (int): The current index in the sequence processing.
|
|
||||||
- threshold (float or None): A threshold value used to determine the minimum attention weight sum.
|
|
||||||
- num_to_choose (int or None): The number of blocks to be selected, ensuring sufficient information retrieval.
|
|
||||||
- decoding (bool): If True, operates in decoding mode; otherwise, it's in encoding mode.
|
|
||||||
- mode (str): Defines the processing mode, either 'both', 'prefill', or 'decode'.
|
|
||||||
- causal (bool): If True, applies causal masking to prevent future information leakage.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
- torch.Tensor: A boolean mask of shape (batch_size, head_num, chunk_num, block_num),
|
|
||||||
indicating which blocks should be attended to.
|
|
||||||
"""
|
|
||||||
assert threshold is None or num_to_choose is None
|
|
||||||
batch_size, head_num, chunk_num, block_num = input_tensor.shape
|
|
||||||
|
|
||||||
if mode == "prefill" and decoding:
|
|
||||||
return torch.ones_like(input_tensor, dtype=torch.bool)
|
|
||||||
if mode == "decode" and not decoding:
|
|
||||||
mask = torch.ones_like(input_tensor, dtype=torch.bool)
|
|
||||||
if causal:
|
|
||||||
mask[:, :, :, current_index : current_index + chunk_num] = torch.tril(
|
|
||||||
torch.ones(1, head_num, chunk_num, chunk_num, device=input_tensor.device)
|
|
||||||
)
|
|
||||||
mask[:, :, current_index + chunk_num :, :] = 0
|
|
||||||
return torch.cat(
|
|
||||||
[
|
|
||||||
torch.ones_like(input_tensor, dtype=torch.bool)[:, :, 0 : current_index + 1],
|
|
||||||
torch.zeros_like(input_tensor, dtype=torch.bool)[:, :, current_index + 1 :],
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
return mask
|
|
||||||
|
|
||||||
input_tensor = input_tensor.to(float)
|
|
||||||
|
|
||||||
if threshold is not None:
|
|
||||||
total_sum = input_tensor.sum(dim=-1, keepdim=True)
|
|
||||||
if isinstance(threshold, torch.Tensor):
|
|
||||||
threshold = threshold.to(float)
|
|
||||||
required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze(
|
|
||||||
-1
|
|
||||||
).expand((batch_size, head_num, chunk_num, 1)).to(input_tensor.device)
|
|
||||||
else:
|
|
||||||
required_sum = total_sum * threshold
|
|
||||||
|
|
||||||
if causal:
|
|
||||||
mask = torch.zeros_like(input_tensor, dtype=torch.bool)
|
|
||||||
mask[:, :, :, 0] = 1
|
|
||||||
mask[:, :, :, current_index : current_index + chunk_num] = (
|
|
||||||
torch.eye(chunk_num, device=mask.device)
|
|
||||||
.unsqueeze(0)
|
|
||||||
.unsqueeze(0)
|
|
||||||
.expand(1, head_num, chunk_num, chunk_num)
|
|
||||||
)
|
|
||||||
other_values = input_tensor.masked_fill(mask, 0)
|
|
||||||
sorted_values, _ = torch.sort(
|
|
||||||
other_values, dim=-1, descending=True
|
|
||||||
)
|
|
||||||
sorted_values = sorted_values.to(input_tensor.device)
|
|
||||||
|
|
||||||
sorted_values = torch.cat(
|
|
||||||
[
|
|
||||||
torch.zeros(
|
|
||||||
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
|
|
||||||
),
|
|
||||||
torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True),
|
|
||||||
sorted_values[:, :, :, :-2],
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
_, index = torch.sort(
|
|
||||||
torch.where(mask, 100000 * (1 + input_tensor), input_tensor),
|
|
||||||
dim=-1,
|
|
||||||
descending=True
|
|
||||||
)
|
|
||||||
cumulative_sum_without_self = torch.cat(
|
|
||||||
[
|
|
||||||
torch.zeros(
|
|
||||||
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
|
|
||||||
),
|
|
||||||
sorted_values[:, :, :, 0:-1],
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
).cumsum(dim=-1)
|
|
||||||
|
|
||||||
index_mask = cumulative_sum_without_self < required_sum
|
|
||||||
index = torch.where(index_mask, index, 0)
|
|
||||||
mask = mask.view(batch_size, head_num * chunk_num, block_num)
|
|
||||||
index = index.view(batch_size, head_num * chunk_num, block_num)
|
|
||||||
mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True
|
|
||||||
mask = mask.view(batch_size, head_num, chunk_num, block_num)
|
|
||||||
else:
|
|
||||||
mask = torch.zeros_like(input_tensor, dtype=torch.bool)
|
|
||||||
sorted_values, index = torch.sort(
|
|
||||||
input_tensor, dim=-1, descending=True
|
|
||||||
)
|
|
||||||
sorted_values = sorted_values.to(input_tensor.device)
|
|
||||||
cumulative_sum_without_self = torch.cat(
|
|
||||||
[
|
|
||||||
torch.zeros(
|
|
||||||
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
|
|
||||||
),
|
|
||||||
sorted_values[:, :, :, 0:-1],
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
).cumsum(dim=-1)
|
|
||||||
index_mask = cumulative_sum_without_self < required_sum
|
|
||||||
index = torch.where(index_mask, index, 0)
|
|
||||||
mask = mask.view(batch_size, head_num * chunk_num, block_num)
|
|
||||||
index = index.view(batch_size, head_num * chunk_num, block_num)
|
|
||||||
mask[
|
|
||||||
:,
|
|
||||||
torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1),
|
|
||||||
index,
|
|
||||||
] = True
|
|
||||||
mask = mask.view(batch_size, head_num, chunk_num, block_num)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError("block num chunk prefill not implemented")
|
|
||||||
|
|
||||||
try:
|
|
||||||
if causal:
|
|
||||||
assert (~mask[:, :, :, current_index + chunk_num :]).all()
|
|
||||||
except:
|
|
||||||
mask[:, :, :, current_index + chunk_num :] = False
|
|
||||||
|
|
||||||
if causal:
|
|
||||||
if decoding:
|
|
||||||
assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all()
|
|
||||||
else:
|
|
||||||
lambda_mask = torch.zeros_like(input_tensor, dtype=bool, device=input_tensor.device)
|
|
||||||
lambda_mask[:, :, :, 0] = 1
|
|
||||||
lambda_mask[:, :, :, current_index:current_index+chunk_num] = torch.eye(
|
|
||||||
chunk_num, device=lambda_mask.device
|
|
||||||
).unsqueeze(0).unsqueeze(0).expand(1, head_num, chunk_num, chunk_num)
|
|
||||||
assert(torch.where(lambda_mask, mask, True).all())
|
|
||||||
|
|
||||||
return mask
|
|
||||||
@@ -1,464 +0,0 @@
|
|||||||
"""
|
|
||||||
XAttention sparse attention policy for nano-vllm.
|
|
||||||
|
|
||||||
Implements the XAttention algorithm from COMPASS, using chunked estimation
|
|
||||||
and block sparse attention for efficient long-context inference.
|
|
||||||
|
|
||||||
Reference: COMPASS/compass/src/Xattention.py
|
|
||||||
"""
|
|
||||||
|
|
||||||
import math
|
|
||||||
from typing import List, Optional
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
|
||||||
from nanovllm.kvcache.sparse.kernels import (
|
|
||||||
flat_group_gemm_fuse_reshape,
|
|
||||||
softmax_fuse_block_sum,
|
|
||||||
)
|
|
||||||
from nanovllm.kvcache.sparse.utils import find_blocks_chunked
|
|
||||||
|
|
||||||
|
|
||||||
class XAttentionPolicy(SparsePolicy):
|
|
||||||
"""
|
|
||||||
XAttention sparse prefill policy using chunked estimation + block sparse attention.
|
|
||||||
|
|
||||||
This policy estimates sparse attention patterns by:
|
|
||||||
1. Chunked QK computation using Triton kernels
|
|
||||||
2. Block-wise softmax with importance scores
|
|
||||||
3. Block selection based on threshold
|
|
||||||
4. Block sparse attention computation
|
|
||||||
|
|
||||||
Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
|
|
||||||
"""
|
|
||||||
|
|
||||||
supports_prefill = True
|
|
||||||
supports_decode = False # XAttention is prefill-only
|
|
||||||
requires_block_selection = False # Only affects attention computation
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
stride: int = 8,
|
|
||||||
threshold: float = 0.9,
|
|
||||||
chunk_size: Optional[int] = None,
|
|
||||||
use_triton: bool = True,
|
|
||||||
keep_sink: bool = False,
|
|
||||||
keep_recent: bool = False,
|
|
||||||
norm: float = 1.0,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize XAttention policy.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stride: Stride for reorganizing Q/K (default: 8)
|
|
||||||
threshold: Block selection threshold, 0-1 (default: 0.9)
|
|
||||||
chunk_size: Chunk size for estimation (auto if None)
|
|
||||||
use_triton: Use Triton kernels (requires SM 80+)
|
|
||||||
keep_sink: Always keep first block (sink tokens)
|
|
||||||
keep_recent: Always keep recent diagonal blocks
|
|
||||||
norm: Normalization factor for attention scores
|
|
||||||
"""
|
|
||||||
self.stride = stride
|
|
||||||
self.threshold = threshold
|
|
||||||
self.chunk_size = chunk_size
|
|
||||||
self.use_triton = use_triton
|
|
||||||
self.keep_sink = keep_sink
|
|
||||||
self.keep_recent = keep_recent
|
|
||||||
self.norm = norm
|
|
||||||
|
|
||||||
# Check Triton availability
|
|
||||||
if self.use_triton:
|
|
||||||
try:
|
|
||||||
import triton
|
|
||||||
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
|
||||||
if props.major < 8:
|
|
||||||
self.use_triton = False
|
|
||||||
print(f"XAttention: Triton requires SM 80+, got SM {props.major}{props.minor}. Falling back to PyTorch.")
|
|
||||||
except ImportError:
|
|
||||||
self.use_triton = False
|
|
||||||
print("XAttention: Triton not available. Falling back to PyTorch.")
|
|
||||||
|
|
||||||
def select_blocks(
|
|
||||||
self,
|
|
||||||
available_blocks: List[int],
|
|
||||||
ctx: PolicyContext,
|
|
||||||
) -> List[int]:
|
|
||||||
"""
|
|
||||||
Select blocks for decode phase.
|
|
||||||
|
|
||||||
XAttention is prefill-only, so this method is only used as a fallback.
|
|
||||||
Returns all available blocks by default.
|
|
||||||
"""
|
|
||||||
# XAttention is prefill-only, but we need to implement this abstract method
|
|
||||||
# Since requires_block_selection=False, this won't be called for loading
|
|
||||||
return available_blocks
|
|
||||||
|
|
||||||
def sparse_prefill_attention(
|
|
||||||
self,
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
layer_id: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute XAttention sparse attention for prefill.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
q: Query tensor [seq_len, num_heads, head_dim]
|
|
||||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
layer_id: Current transformer layer index
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Attention output [seq_len, num_heads, head_dim]
|
|
||||||
"""
|
|
||||||
seq_len = q.shape[0]
|
|
||||||
num_heads = q.shape[1]
|
|
||||||
head_dim = q.shape[2]
|
|
||||||
num_kv_heads = k.shape[1]
|
|
||||||
|
|
||||||
# Use FlashAttention directly for CPU offload mode
|
|
||||||
# FlashAttention supports GQA natively
|
|
||||||
try:
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
|
||||||
|
|
||||||
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
|
||||||
|
|
||||||
attn_output = flash_attn_varlen_func(
|
|
||||||
q, k, v,
|
|
||||||
cu_seqlens_q=cu_seqlens,
|
|
||||||
cu_seqlens_k=cu_seqlens,
|
|
||||||
max_seqlen_q=seq_len,
|
|
||||||
max_seqlen_k=seq_len,
|
|
||||||
softmax_scale=1.0 / math.sqrt(head_dim),
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Fallback: PyTorch SDPA (supports GQA natively)
|
|
||||||
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
|
|
||||||
attn_output = F.scaled_dot_product_attention(
|
|
||||||
q, k, v,
|
|
||||||
attn_mask=None,
|
|
||||||
is_causal=True,
|
|
||||||
scale=1.0 / math.sqrt(head_dim)
|
|
||||||
)
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
def _xattn_offload_prefill(
|
|
||||||
self,
|
|
||||||
query_states: torch.Tensor,
|
|
||||||
key_states: torch.Tensor,
|
|
||||||
value_states: torch.Tensor,
|
|
||||||
causal: bool = True,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Simplified XAttention prefill for CPU offload mode.
|
|
||||||
|
|
||||||
Uses FlashAttention with full context since chunked estimation
|
|
||||||
with full key_states requires special handling.
|
|
||||||
"""
|
|
||||||
batch_size, num_heads, q_len, head_dim = query_states.shape
|
|
||||||
_, _, k_len, _ = key_states.shape
|
|
||||||
|
|
||||||
# Use FlashAttention with full context
|
|
||||||
# In offload mode, keys are already on CPU and loaded as needed
|
|
||||||
try:
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
|
||||||
|
|
||||||
# Convert to [seq, heads, dim] format
|
|
||||||
q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim]
|
|
||||||
k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
|
||||||
v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
|
||||||
|
|
||||||
cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
|
|
||||||
cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)
|
|
||||||
|
|
||||||
attn_output = flash_attn_varlen_func(
|
|
||||||
q, k, v,
|
|
||||||
cu_seqlens_q=cu_seqlens_q,
|
|
||||||
cu_seqlens_k=cu_seqlens_k,
|
|
||||||
max_seqlen_q=q_len,
|
|
||||||
max_seqlen_k=k_len,
|
|
||||||
softmax_scale=1.0 / math.sqrt(head_dim),
|
|
||||||
causal=causal,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert back to [batch, seq, heads, dim]
|
|
||||||
attn_output = attn_output.unsqueeze(0).transpose(1, 2) # [1, q_len, num_heads, head_dim]
|
|
||||||
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Final fallback: PyTorch SDPA
|
|
||||||
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
|
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
|
|
||||||
attn_output = F.scaled_dot_product_attention(
|
|
||||||
query_states, key_states, value_states,
|
|
||||||
attn_mask=None,
|
|
||||||
is_causal=causal,
|
|
||||||
scale=1.0 / math.sqrt(head_dim)
|
|
||||||
)
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
def _xattn_prefill(
|
|
||||||
self,
|
|
||||||
query_states: torch.Tensor,
|
|
||||||
key_states: torch.Tensor,
|
|
||||||
value_states: torch.Tensor,
|
|
||||||
stride: int,
|
|
||||||
norm: float,
|
|
||||||
threshold: float,
|
|
||||||
block_size: int = 128,
|
|
||||||
use_triton: bool = True,
|
|
||||||
causal: bool = True,
|
|
||||||
chunk_size: Optional[int] = None,
|
|
||||||
keep_sink: bool = False,
|
|
||||||
keep_recent: bool = False,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
XAttention prefill implementation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_states: [batch, num_heads, q_len, head_dim]
|
|
||||||
key_states: [batch, num_heads, k_len, head_dim]
|
|
||||||
value_states: [batch, num_heads, k_len, head_dim]
|
|
||||||
... other params
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Attention output [batch, q_len, num_heads, head_dim]
|
|
||||||
"""
|
|
||||||
batch_size, num_heads, k_len, head_dim = key_states.shape
|
|
||||||
_, _, q_len, _ = query_states.shape
|
|
||||||
|
|
||||||
# Auto-compute chunk_size if not specified
|
|
||||||
if chunk_size is None:
|
|
||||||
chunk_size = int(
|
|
||||||
max(
|
|
||||||
min(
|
|
||||||
max(2048, 1 << (k_len - 1).bit_length()),
|
|
||||||
128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()),
|
|
||||||
),
|
|
||||||
2048,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Phase 1: Estimate sparse pattern
|
|
||||||
attn_sums, approx_simple_mask = self._xattn_estimate(
|
|
||||||
query_states,
|
|
||||||
key_states,
|
|
||||||
block_size=block_size,
|
|
||||||
stride=stride,
|
|
||||||
norm=norm,
|
|
||||||
threshold=threshold,
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
use_triton=use_triton,
|
|
||||||
causal=causal,
|
|
||||||
keep_sink=keep_sink,
|
|
||||||
keep_recent=keep_recent,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Phase 2: Block sparse attention
|
|
||||||
# For now, use FlashAttention as fallback since block_sparse_attn_func may not be available
|
|
||||||
attn_output = self._block_sparse_attention_fallback(
|
|
||||||
query_states, key_states, value_states,
|
|
||||||
approx_simple_mask, block_size, q_len, k_len
|
|
||||||
)
|
|
||||||
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
def _xattn_estimate(
|
|
||||||
self,
|
|
||||||
query_states: torch.Tensor,
|
|
||||||
key_states: torch.Tensor,
|
|
||||||
block_size: int,
|
|
||||||
stride: int,
|
|
||||||
norm: float = 1,
|
|
||||||
softmax: bool = True,
|
|
||||||
threshold: float = 0.9,
|
|
||||||
chunk_size: int = 16384,
|
|
||||||
use_triton: bool = True,
|
|
||||||
causal: bool = True,
|
|
||||||
keep_sink: bool = False,
|
|
||||||
keep_recent: bool = False,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Estimate sparse attention pattern using chunked computation.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
attn_sums: [batch, heads, q_blocks, k_blocks] - importance scores
|
|
||||||
simple_masks: [batch, heads, q_blocks, k_blocks] - boolean masks
|
|
||||||
"""
|
|
||||||
batch_size, num_kv_head, k_len, head_dim = key_states.shape
|
|
||||||
batch_size, num_q_head, q_len, head_dim = query_states.shape
|
|
||||||
|
|
||||||
k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
|
|
||||||
q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
|
|
||||||
k_chunk_num = (k_len + k_num_to_pad) // chunk_size
|
|
||||||
k_block_num = (k_len + k_num_to_pad) // block_size
|
|
||||||
q_chunk_num = (q_len + q_num_to_pad) // chunk_size
|
|
||||||
q_block_num = (q_len + q_num_to_pad) // block_size
|
|
||||||
|
|
||||||
# Pad inputs
|
|
||||||
if k_num_to_pad > 0:
|
|
||||||
pad_key_states = F.pad(key_states, (0, 0, 0, k_num_to_pad), value=0)
|
|
||||||
else:
|
|
||||||
pad_key_states = key_states
|
|
||||||
if q_num_to_pad > 0:
|
|
||||||
pad_query_states = F.pad(query_states, (0, 0, 0, q_num_to_pad), value=0)
|
|
||||||
else:
|
|
||||||
pad_query_states = query_states
|
|
||||||
|
|
||||||
reshaped_chunk_size = chunk_size // stride
|
|
||||||
reshaped_block_size = block_size // stride
|
|
||||||
k_reshaped_seq_len = (k_len + k_num_to_pad) // stride
|
|
||||||
|
|
||||||
attn_sum_list = []
|
|
||||||
simple_mask_list = []
|
|
||||||
|
|
||||||
for chunk_idx in range(q_chunk_num):
|
|
||||||
if use_triton:
|
|
||||||
# Triton GEMM + Softmax
|
|
||||||
attn_weights_slice = flat_group_gemm_fuse_reshape(
|
|
||||||
pad_query_states[:, :, (chunk_idx * reshaped_chunk_size) * stride : (chunk_idx * reshaped_chunk_size + reshaped_chunk_size) * stride, :],
|
|
||||||
pad_key_states,
|
|
||||||
stride,
|
|
||||||
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
|
|
||||||
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
|
|
||||||
is_causal=causal,
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_sum = softmax_fuse_block_sum(
|
|
||||||
attn_weights_slice,
|
|
||||||
reshaped_block_size,
|
|
||||||
min(4096, reshaped_block_size),
|
|
||||||
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
|
|
||||||
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
|
|
||||||
k_reshaped_seq_len - (k_num_to_pad // stride),
|
|
||||||
1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
|
|
||||||
is_causal=causal,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# PyTorch fallback
|
|
||||||
chunk_size_actual = reshaped_chunk_size
|
|
||||||
chunk_start = chunk_idx * chunk_size_actual
|
|
||||||
chunk_end = chunk_start + chunk_size_actual
|
|
||||||
|
|
||||||
chunked_query = pad_query_states[:, :, chunk_start * stride:chunk_end * stride:stride, :]
|
|
||||||
attn_weights_slice = torch.matmul(chunked_query, pad_key_states.transpose(2, 3))
|
|
||||||
attn_weights_slice = attn_weights_slice / math.sqrt(head_dim) / stride / norm
|
|
||||||
|
|
||||||
if causal:
|
|
||||||
causal_mask = torch.zeros((batch_size, num_q_head, chunk_size_actual, chunk_size_actual * k_chunk_num), device=key_states.device)
|
|
||||||
causal_mask[:, :, :, -(k_num_to_pad // stride):] = float("-inf")
|
|
||||||
# ... more causal mask logic ...
|
|
||||||
attn_weights_slice = attn_weights_slice + causal_mask
|
|
||||||
|
|
||||||
attn_weights_slice = F.softmax(attn_weights_slice, dim=-1, dtype=torch.float32)
|
|
||||||
attn_sum = attn_weights_slice.view(batch_size, num_q_head, chunk_size_actual // reshaped_block_size, reshaped_block_size, -1).sum(dim=-1).sum(dim=-2)
|
|
||||||
|
|
||||||
# Find blocks based on threshold
|
|
||||||
simple_mask = find_blocks_chunked(
|
|
||||||
attn_sum,
|
|
||||||
k_block_num - q_block_num + chunk_idx * (reshaped_chunk_size // reshaped_block_size),
|
|
||||||
threshold,
|
|
||||||
None,
|
|
||||||
decoding=False,
|
|
||||||
mode="prefill",
|
|
||||||
causal=causal,
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_sum_list.append(attn_sum)
|
|
||||||
simple_mask_list.append(simple_mask)
|
|
||||||
|
|
||||||
attn_sums = torch.cat(attn_sum_list, dim=-2)
|
|
||||||
simple_masks = torch.cat(simple_mask_list, dim=-2)
|
|
||||||
|
|
||||||
# Apply causal mask to block masks
|
|
||||||
if causal:
|
|
||||||
simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
|
|
||||||
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=key_states.device), diagonal=0),
|
|
||||||
simple_masks[:, :, -q_block_num:, -q_block_num:],
|
|
||||||
False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if keep_sink:
|
|
||||||
simple_masks[:, :, 0, :] = True
|
|
||||||
|
|
||||||
if keep_recent:
|
|
||||||
eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=bool)
|
|
||||||
eye_matrix_expanded = eye_matrix.unsqueeze(0).unsqueeze(0).expand(1, num_q_head, q_block_num, q_block_num)
|
|
||||||
simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
|
|
||||||
eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:]
|
|
||||||
)
|
|
||||||
|
|
||||||
return attn_sums, simple_masks
|
|
||||||
|
|
||||||
def _block_sparse_attention_fallback(
|
|
||||||
self,
|
|
||||||
query_states: torch.Tensor,
|
|
||||||
key_states: torch.Tensor,
|
|
||||||
value_states: torch.Tensor,
|
|
||||||
mask: torch.Tensor,
|
|
||||||
block_size: int,
|
|
||||||
q_len: int,
|
|
||||||
k_len: int,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Fallback implementation using FlashAttention.
|
|
||||||
|
|
||||||
Since block_sparse_attn_func may not be available in all environments,
|
|
||||||
this uses standard FlashAttention with full attention.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
|
||||||
|
|
||||||
batch_size, num_heads, _, head_dim = query_states.shape
|
|
||||||
|
|
||||||
# Convert to [seq, heads, dim] format
|
|
||||||
q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim]
|
|
||||||
k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
|
||||||
v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
|
||||||
|
|
||||||
cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
|
|
||||||
cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)
|
|
||||||
|
|
||||||
attn_output = flash_attn_varlen_func(
|
|
||||||
q, k, v,
|
|
||||||
cu_seqlens_q=cu_seqlens_q,
|
|
||||||
cu_seqlens_k=cu_seqlens_k,
|
|
||||||
max_seqlen_q=q_len,
|
|
||||||
max_seqlen_k=k_len,
|
|
||||||
softmax_scale=1.0 / math.sqrt(head_dim),
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert back to [batch, seq, heads, dim]
|
|
||||||
attn_output = attn_output.unsqueeze(0).transpose(1, 2)
|
|
||||||
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# Final fallback: PyTorch SDPA
|
|
||||||
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
|
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
|
|
||||||
attn_output = F.scaled_dot_product_attention(
|
|
||||||
query_states, key_states, value_states,
|
|
||||||
attn_mask=None,
|
|
||||||
is_causal=True,
|
|
||||||
scale=1.0 / math.sqrt(query_states.shape[-1])
|
|
||||||
)
|
|
||||||
return attn_output
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Reset policy state (no state to reset for XAttention)."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return (f"XAttentionPolicy("
|
|
||||||
f"stride={self.stride}, "
|
|
||||||
f"threshold={self.threshold}, "
|
|
||||||
f"use_triton={self.use_triton})")
|
|
||||||
@@ -1,8 +1,13 @@
|
|||||||
|
import logging
|
||||||
import torch
|
import torch
|
||||||
|
import torch.cuda.nvtx
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
from nanovllm.utils.context import get_context
|
from nanovllm.utils.context import get_context
|
||||||
|
from nanovllm.kvcache.sparse.policy import PolicyContext
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def store_kvcache(
|
def store_kvcache(
|
||||||
@@ -55,17 +60,12 @@ def store_kvcache(
|
|||||||
valid_values_flat = valid_values.reshape(-1, D)
|
valid_values_flat = valid_values.reshape(-1, D)
|
||||||
|
|
||||||
# In-place scatter using index_copy_
|
# In-place scatter using index_copy_
|
||||||
|
# 即使 valid_slots 为空张量,index_copy_ 也是安全的(不会修改数据)。
|
||||||
k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat)
|
k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat)
|
||||||
v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat)
|
v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat)
|
||||||
|
|
||||||
|
|
||||||
class Attention(nn.Module):
|
class Attention(nn.Module):
|
||||||
"""
|
|
||||||
Attention layer for GPU-only mode.
|
|
||||||
|
|
||||||
For CPU offload mode, attention is computed directly in model_runner's
|
|
||||||
run_layerwise_offload_prefill/decode methods using FlashAttention.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -87,29 +87,635 @@ class Attention(nn.Module):
|
|||||||
context = get_context()
|
context = get_context()
|
||||||
k_cache, v_cache = self.k_cache, self.v_cache
|
k_cache, v_cache = self.k_cache, self.v_cache
|
||||||
|
|
||||||
# Store KV to cache (for GPU-only mode)
|
# Determine if we're in chunked offload mode
|
||||||
if k_cache.numel() and v_cache.numel():
|
is_chunked_offload = (
|
||||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
context.is_chunked_prefill and
|
||||||
|
hasattr(context, 'kvcache_manager') and
|
||||||
|
context.kvcache_manager is not None and
|
||||||
|
hasattr(context.kvcache_manager, 'offload_engine')
|
||||||
|
)
|
||||||
|
|
||||||
|
#! Ensure synchronization before accessing k_cache/v_cache
|
||||||
|
# torch.cuda.synchronize()
|
||||||
|
#! =======================================================
|
||||||
|
|
||||||
|
if is_chunked_offload and context.is_prefill:
|
||||||
|
# Chunked prefill mode: write KV to per-layer prefill buffer (not GPU slot)
|
||||||
|
# This enables fully async offloads since each layer has its own buffer.
|
||||||
|
offload_engine = context.kvcache_manager.offload_engine
|
||||||
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
|
# Wait for default stream to ensure slot_mapping tensor transfer is complete
|
||||||
|
compute_stream.wait_stream(torch.cuda.default_stream())
|
||||||
|
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
# Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
|
||||||
|
# k, v shape: [num_tokens, kv_heads, head_dim]
|
||||||
|
num_tokens = k.shape[0]
|
||||||
|
offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
|
||||||
|
offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
|
||||||
|
elif is_chunked_offload:
|
||||||
|
# Chunked decode mode: use compute_stream for store_kvcache
|
||||||
|
# This ensures proper synchronization with per-layer offload
|
||||||
|
compute_stream = context.kvcache_manager.offload_engine.compute_stream
|
||||||
|
if k_cache.numel() and v_cache.numel():
|
||||||
|
# CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
|
||||||
|
# slot_mapping is created with non_blocking=True on default stream, but we use it
|
||||||
|
# on compute_stream. Without this sync, index_copy_ can get corrupted indices.
|
||||||
|
compute_stream.wait_stream(torch.cuda.default_stream())
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||||
|
else:
|
||||||
|
# Normal mode: store on default stream
|
||||||
|
if k_cache.numel() and v_cache.numel():
|
||||||
|
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||||
|
|
||||||
if context.is_prefill:
|
if context.is_prefill:
|
||||||
if context.block_tables is not None: # prefix cache
|
if context.is_chunked_prefill:
|
||||||
|
# Chunked prefill: merge attention from previous KV
|
||||||
|
o = self._chunked_prefill_attention(q, k, v, context)
|
||||||
|
elif context.block_tables is not None: # prefix cache
|
||||||
k, v = k_cache, v_cache
|
k, v = k_cache, v_cache
|
||||||
o = flash_attn_varlen_func(q, k, v,
|
o = flash_attn_varlen_func(q, k, v,
|
||||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||||
elif context.sparse_prefill_policy is not None:
|
|
||||||
# Sparse prefill (GPU-only) - delegate to policy
|
|
||||||
o = context.sparse_prefill_policy.sparse_prefill_attention(
|
|
||||||
q, k, v, self.layer_id
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
o = flash_attn_varlen_func(q, k, v,
|
o = flash_attn_varlen_func(q, k, v,
|
||||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||||
else: # decode
|
else: # decode
|
||||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
if context.is_chunked_prefill:
|
||||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
# Chunked decode: need to load all KV from CPU+GPU
|
||||||
softmax_scale=self.scale, causal=True)
|
# Store current decode token to per-layer decode buffer
|
||||||
|
# This is needed because GPU cache has no layer dimension,
|
||||||
|
# so all layers would overwrite each other in decode_slot.
|
||||||
|
kvcache_manager = context.kvcache_manager
|
||||||
|
offload_engine = kvcache_manager.offload_engine
|
||||||
|
pos_in_block = context.decode_pos_in_block
|
||||||
|
# k, v shape: [1, kv_heads, head_dim]
|
||||||
|
offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
|
||||||
|
offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
|
||||||
|
o = self._chunked_decode_attention(q, k, v, context)
|
||||||
|
else:
|
||||||
|
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
||||||
|
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||||
|
softmax_scale=self.scale, causal=True)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
def _chunked_prefill_attention(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
context,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute attention with per-layer prefill buffer for async offload.
|
||||||
|
|
||||||
|
Optimized design:
|
||||||
|
- Current chunk's KV is written to per-layer prefill buffer (not GPU slot)
|
||||||
|
- Previous chunks' KV are loaded from CPU using GPU slots
|
||||||
|
- Each layer offloads from its own buffer - no waiting required!
|
||||||
|
|
||||||
|
For each layer:
|
||||||
|
1. Current chunk's KV is in prefill_buffer[layer_id] (just written by model)
|
||||||
|
2. Load previous chunks from CPU using available slots (pipeline)
|
||||||
|
3. Compute attention against previous KV (no causal mask)
|
||||||
|
4. Compute attention against current KV from prefill buffer (causal)
|
||||||
|
5. Merge all results using online softmax
|
||||||
|
6. Async offload prefill buffer to CPU (no waiting!)
|
||||||
|
"""
|
||||||
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
|
current_chunk_idx = context.current_chunk_idx
|
||||||
|
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
|
||||||
|
|
||||||
|
# q shape: [total_tokens, num_heads, head_dim]
|
||||||
|
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
|
||||||
|
num_tokens = k.shape[0]
|
||||||
|
|
||||||
|
o_acc = None
|
||||||
|
lse_acc = None
|
||||||
|
|
||||||
|
kvcache_manager = context.kvcache_manager
|
||||||
|
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
|
||||||
|
offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None
|
||||||
|
|
||||||
|
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
|
||||||
|
# Get prefilled CPU blocks (blocks from previous chunks)
|
||||||
|
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||||
|
|
||||||
|
# Apply sparse policy if enabled (Quest returns all blocks for prefill since query=None)
|
||||||
|
sparse_policy = kvcache_manager.sparse_policy
|
||||||
|
if cpu_block_table and sparse_policy is not None:
|
||||||
|
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
|
||||||
|
policy_ctx = PolicyContext(
|
||||||
|
query_chunk_idx=current_chunk_idx,
|
||||||
|
num_query_chunks=num_chunks,
|
||||||
|
layer_id=self.layer_id,
|
||||||
|
query=None, # Prefill typically doesn't use query for selection
|
||||||
|
is_prefill=True,
|
||||||
|
block_size=kvcache_manager.block_size,
|
||||||
|
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||||
|
)
|
||||||
|
cpu_block_table = sparse_policy.select_blocks(
|
||||||
|
cpu_block_table, policy_ctx
|
||||||
|
)
|
||||||
|
|
||||||
|
if cpu_block_table:
|
||||||
|
# Get available load slots (all slots can be used since we use prefill buffer)
|
||||||
|
load_slots = list(range(offload_engine.num_ring_slots))
|
||||||
|
pipeline_depth = len(load_slots)
|
||||||
|
|
||||||
|
if pipeline_depth == 0:
|
||||||
|
# Only 1 slot total, cannot pipeline - use sync loading
|
||||||
|
o_acc, lse_acc = self._sync_load_previous_chunks(
|
||||||
|
q_batched, cpu_block_table, offload_engine
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Use ring buffer pipeline
|
||||||
|
o_acc, lse_acc = self._ring_buffer_pipeline_load(
|
||||||
|
q_batched, cpu_block_table, load_slots, offload_engine,
|
||||||
|
current_chunk_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get compute stream for all attention operations
|
||||||
|
compute_stream = offload_engine.compute_stream if offload_engine is not None else None
|
||||||
|
|
||||||
|
# Compute attention against current chunk's KV from prefill buffer (with causal mask)
|
||||||
|
if compute_stream is not None:
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
||||||
|
# Get KV from per-layer prefill buffer
|
||||||
|
k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
|
||||||
|
current_o, current_lse = flash_attn_with_lse(
|
||||||
|
q_batched,
|
||||||
|
k_batched,
|
||||||
|
v_batched,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
torch.cuda.nvtx.range_pop()
|
||||||
|
else:
|
||||||
|
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
|
||||||
|
k_batched = k.unsqueeze(0)
|
||||||
|
v_batched = v.unsqueeze(0)
|
||||||
|
current_o, current_lse = flash_attn_with_lse(
|
||||||
|
q_batched,
|
||||||
|
k_batched,
|
||||||
|
v_batched,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
torch.cuda.nvtx.range_pop()
|
||||||
|
|
||||||
|
# Merge with accumulated (all on compute_stream for consistency)
|
||||||
|
if o_acc is None:
|
||||||
|
final_o = current_o
|
||||||
|
else:
|
||||||
|
if compute_stream is not None:
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
|
||||||
|
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
||||||
|
torch.cuda.nvtx.range_pop()
|
||||||
|
else:
|
||||||
|
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
|
||||||
|
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
|
||||||
|
torch.cuda.nvtx.range_pop()
|
||||||
|
|
||||||
|
torch.cuda.nvtx.range_pop() # ChunkedPrefill
|
||||||
|
|
||||||
|
# Per-layer ASYNC offload: offload prefill buffer to CPU
|
||||||
|
# No waiting required! Each layer has its own buffer and stream.
|
||||||
|
if offload_engine is not None and seq is not None:
|
||||||
|
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
|
||||||
|
if current_chunk_idx < len(cpu_block_ids):
|
||||||
|
cpu_block_id = cpu_block_ids[current_chunk_idx]
|
||||||
|
# Async offload - no waiting, fully parallel across layers
|
||||||
|
offload_engine.offload_prefill_buffer_async(
|
||||||
|
self.layer_id, cpu_block_id, num_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sync default stream with compute_stream before returning
|
||||||
|
# This ensures the result is ready for the rest of the model (layernorm, MLP)
|
||||||
|
if compute_stream is not None:
|
||||||
|
torch.cuda.default_stream().wait_stream(compute_stream)
|
||||||
|
|
||||||
|
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
|
||||||
|
return final_o.squeeze(0)
|
||||||
|
|
||||||
|
def _sync_load_previous_chunks(
|
||||||
|
self,
|
||||||
|
q_batched: torch.Tensor,
|
||||||
|
cpu_block_table: list,
|
||||||
|
offload_engine,
|
||||||
|
):
|
||||||
|
"""Synchronous loading fallback when pipeline_depth=0."""
|
||||||
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
|
o_acc, lse_acc = None, None
|
||||||
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
|
for block_idx, cpu_block_id in enumerate(cpu_block_table):
|
||||||
|
# Load to slot 0 (single slot)
|
||||||
|
offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
|
||||||
|
offload_engine.wait_slot_layer(0)
|
||||||
|
|
||||||
|
# IMPORTANT: Must use compute_stream to match wait_slot_layer
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
prev_k, prev_v = offload_engine.get_kv_for_slot(0)
|
||||||
|
|
||||||
|
prev_o, prev_lse = flash_attn_with_lse(
|
||||||
|
q_batched, prev_k, prev_v,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if o_acc is None:
|
||||||
|
o_acc, lse_acc = prev_o, prev_lse
|
||||||
|
else:
|
||||||
|
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
||||||
|
|
||||||
|
return o_acc, lse_acc
|
||||||
|
|
||||||
|
def _ring_buffer_pipeline_load(
|
||||||
|
self,
|
||||||
|
q_batched: torch.Tensor,
|
||||||
|
cpu_block_table: list,
|
||||||
|
load_slots: list,
|
||||||
|
offload_engine,
|
||||||
|
current_chunk_idx: int = -1,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Ring buffer async pipeline loading with double buffering.
|
||||||
|
|
||||||
|
Uses compute_done events to ensure safe buffer reuse:
|
||||||
|
- Before loading to slot X, wait for previous compute on slot X to finish
|
||||||
|
- Before computing on slot X, wait for load to slot X to finish
|
||||||
|
|
||||||
|
Timeline with 2 slots (A, B):
|
||||||
|
┌──────────────┐
|
||||||
|
│ Load B0→A │
|
||||||
|
└──────────────┘
|
||||||
|
┌──────────────┐ ┌──────────────┐
|
||||||
|
│ Load B1→B │ │ Load B2→A │ ...
|
||||||
|
└──────────────┘ └──────────────┘
|
||||||
|
↘ ↘
|
||||||
|
┌──────────────┐ ┌──────────────┐
|
||||||
|
│ Compute(A) │ │ Compute(B) │ ...
|
||||||
|
└──────────────┘ └──────────────┘
|
||||||
|
|
||||||
|
The load_to_slot_layer internally waits for compute_done[slot] before
|
||||||
|
starting the transfer, ensuring no data race.
|
||||||
|
"""
|
||||||
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
|
num_blocks = len(cpu_block_table)
|
||||||
|
if num_blocks == 0:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
pipeline_depth = len(load_slots)
|
||||||
|
if pipeline_depth == 0:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
o_acc, lse_acc = None, None
|
||||||
|
|
||||||
|
if pipeline_depth == 1:
|
||||||
|
# Only 1 slot available, cannot pipeline - use synchronous mode
|
||||||
|
# IMPORTANT: Must use compute_stream to match synchronization in
|
||||||
|
# load_to_slot_layer (waits for compute_done) and wait_slot_layer
|
||||||
|
slot = load_slots[0]
|
||||||
|
compute_stream = offload_engine.compute_stream
|
||||||
|
for block_idx in range(num_blocks):
|
||||||
|
cpu_block_id = cpu_block_table[block_idx]
|
||||||
|
offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
|
||||||
|
offload_engine.wait_slot_layer(slot)
|
||||||
|
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
# Debug: call hooks on compute_stream (synchronized with transfer)
|
||||||
|
if offload_engine.debug_mode:
|
||||||
|
offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
|
||||||
|
|
||||||
|
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
|
||||||
|
|
||||||
|
prev_o, prev_lse = flash_attn_with_lse(
|
||||||
|
q_batched, prev_k, prev_v,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
# Record compute done so next load can safely reuse this slot
|
||||||
|
offload_engine.record_slot_compute_done(slot)
|
||||||
|
if o_acc is None:
|
||||||
|
o_acc, lse_acc = prev_o, prev_lse
|
||||||
|
else:
|
||||||
|
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
||||||
|
return o_acc, lse_acc
|
||||||
|
|
||||||
|
# N-way pipeline: use ALL available slots for maximum overlap
|
||||||
|
# Pipeline depth = num_slots - 1 (num_slots blocks in flight)
|
||||||
|
num_slots = len(load_slots)
|
||||||
|
|
||||||
|
# Phase 1: Pre-load up to num_slots blocks to fill the pipeline
|
||||||
|
# This starts all transfers in parallel, utilizing full PCIe bandwidth
|
||||||
|
num_preload = min(num_slots, num_blocks)
|
||||||
|
for i in range(num_preload):
|
||||||
|
offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
|
||||||
|
|
||||||
|
# Phase 2: Main loop - compute and immediately reuse slot for next transfer
|
||||||
|
# Use dedicated compute_stream (not default stream) to enable overlap with transfers
|
||||||
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
|
for block_idx in range(num_blocks):
|
||||||
|
torch.cuda.nvtx.range_push(f"PipelineBlock: L{self.layer_id} B{block_idx}")
|
||||||
|
|
||||||
|
# Cycle through slots: slot[block_idx % num_slots]
|
||||||
|
current_slot = load_slots[block_idx % num_slots]
|
||||||
|
cpu_block_id = cpu_block_table[block_idx]
|
||||||
|
|
||||||
|
# Wait for current slot's transfer to complete (on compute_stream)
|
||||||
|
offload_engine.wait_slot_layer(current_slot)
|
||||||
|
|
||||||
|
# Compute attention on current slot's data
|
||||||
|
# IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
# Debug: call hooks on compute_stream (synchronized with transfer)
|
||||||
|
if offload_engine.debug_mode:
|
||||||
|
offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)
|
||||||
|
|
||||||
|
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
|
||||||
|
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
|
||||||
|
|
||||||
|
prev_o, prev_lse = flash_attn_with_lse(
|
||||||
|
q_batched, prev_k, prev_v,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
torch.cuda.nvtx.range_pop()
|
||||||
|
|
||||||
|
# Record compute done - this allows the next transfer to safely overwrite this slot
|
||||||
|
offload_engine.record_slot_compute_done(current_slot)
|
||||||
|
|
||||||
|
# Immediately start loading the NEXT block into this slot (if more blocks remain)
|
||||||
|
# Key insight: reuse current_slot immediately after compute is done!
|
||||||
|
next_block_idx = block_idx + num_slots
|
||||||
|
if next_block_idx < num_blocks:
|
||||||
|
offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
|
||||||
|
|
||||||
|
# Merge with accumulated (also on compute_stream for consistency)
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
if o_acc is None:
|
||||||
|
o_acc, lse_acc = prev_o, prev_lse
|
||||||
|
else:
|
||||||
|
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
||||||
|
|
||||||
|
torch.cuda.nvtx.range_pop() # PipelineBlock
|
||||||
|
|
||||||
|
return o_acc, lse_acc
|
||||||
|
|
||||||
|
def _chunked_decode_attention(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
context,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute decode attention using cross-layer pipeline.
|
||||||
|
|
||||||
|
Optimization: Uses double-buffered layer cache to overlap H2D transfer
|
||||||
|
with computation across layers:
|
||||||
|
- Layer N computes while Layer N+1's data is being loaded
|
||||||
|
- Each layer only waits for its own data, not all layers' data
|
||||||
|
|
||||||
|
This reduces effective latency from O(num_layers * transfer_time) to
|
||||||
|
O(transfer_time + num_layers * compute_time) when transfer < compute.
|
||||||
|
"""
|
||||||
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
|
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
|
||||||
|
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
|
||||||
|
|
||||||
|
kvcache_manager = context.kvcache_manager
|
||||||
|
seq = context.chunked_seq
|
||||||
|
|
||||||
|
# Get only PREFILLED CPU blocks (exclude the current decode block)
|
||||||
|
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
|
||||||
|
if self.layer_id == 0:
|
||||||
|
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
|
||||||
|
if not cpu_block_table:
|
||||||
|
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
|
||||||
|
|
||||||
|
# Calculate valid tokens in the last CPU block
|
||||||
|
# CRITICAL: Use original prefill length, not current seq length!
|
||||||
|
# CPU blocks are fixed after prefill, their content doesn't change during decode.
|
||||||
|
block_size = kvcache_manager.block_size
|
||||||
|
num_prefill_blocks = len(cpu_block_table)
|
||||||
|
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Original prefill length
|
||||||
|
last_block_valid_tokens = total_prefill_tokens % block_size
|
||||||
|
if last_block_valid_tokens == 0 and total_prefill_tokens > 0:
|
||||||
|
last_block_valid_tokens = block_size # Last block was exactly full
|
||||||
|
|
||||||
|
# Apply sparse policy if enabled (Quest does Top-K selection for decode)
|
||||||
|
sparse_policy = kvcache_manager.sparse_policy
|
||||||
|
if sparse_policy is not None:
|
||||||
|
policy_ctx = PolicyContext(
|
||||||
|
query_chunk_idx=0,
|
||||||
|
num_query_chunks=1,
|
||||||
|
layer_id=self.layer_id,
|
||||||
|
query=q_batched,
|
||||||
|
is_prefill=False,
|
||||||
|
block_size=kvcache_manager.block_size,
|
||||||
|
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
|
||||||
|
)
|
||||||
|
cpu_block_table = sparse_policy.select_blocks(
|
||||||
|
cpu_block_table, policy_ctx
|
||||||
|
)
|
||||||
|
|
||||||
|
offload_engine = kvcache_manager.offload_engine
|
||||||
|
|
||||||
|
# Use cross-layer pipeline if active (initialized in model_runner)
|
||||||
|
if offload_engine.is_pipeline_active():
|
||||||
|
o_acc, lse_acc = self._decode_with_layer_pipeline(
|
||||||
|
q_batched, cpu_block_table, offload_engine,
|
||||||
|
block_size, last_block_valid_tokens
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback to original ring buffer pipeline
|
||||||
|
load_slots = offload_engine.decode_load_slots
|
||||||
|
o_acc, lse_acc = self._decode_ring_buffer_pipeline(
|
||||||
|
q_batched, cpu_block_table, load_slots, offload_engine,
|
||||||
|
block_size, last_block_valid_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now attend to accumulated decode tokens from per-layer decode buffer
|
||||||
|
pos_in_block = context.decode_pos_in_block
|
||||||
|
start_pos = context.decode_start_pos_in_block
|
||||||
|
num_accumulated = pos_in_block - start_pos + 1
|
||||||
|
|
||||||
|
# Sync compute_stream with default stream before reading decode_buffer
|
||||||
|
compute_stream = offload_engine.compute_stream
|
||||||
|
compute_stream.wait_stream(torch.cuda.default_stream())
|
||||||
|
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
if num_accumulated > 0:
|
||||||
|
# Read from per-layer decode buffer
|
||||||
|
decode_k = offload_engine.decode_k_buffer[self.layer_id, start_pos:pos_in_block+1]
|
||||||
|
decode_v = offload_engine.decode_v_buffer[self.layer_id, start_pos:pos_in_block+1]
|
||||||
|
decode_k = decode_k.unsqueeze(0)
|
||||||
|
decode_v = decode_v.unsqueeze(0)
|
||||||
|
|
||||||
|
decode_o, decode_lse = flash_attn_with_lse(
|
||||||
|
q_batched, decode_k, decode_v,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if o_acc is None:
|
||||||
|
o_acc = decode_o
|
||||||
|
else:
|
||||||
|
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
|
||||||
|
|
||||||
|
if o_acc is None:
|
||||||
|
raise RuntimeError("Chunked decode attention failed: no KV available")
|
||||||
|
|
||||||
|
# Sync back to default stream before returning
|
||||||
|
torch.cuda.default_stream().wait_stream(compute_stream)
|
||||||
|
|
||||||
|
return o_acc
|
||||||
|
|
||||||
|
def _decode_ring_buffer_pipeline(
|
||||||
|
self,
|
||||||
|
q_batched: torch.Tensor,
|
||||||
|
cpu_block_table: list,
|
||||||
|
load_slots: list,
|
||||||
|
offload_engine,
|
||||||
|
block_size: int,
|
||||||
|
last_block_valid_tokens: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Ring buffer pipeline for decode prefill loading (same mechanism as prefill).
|
||||||
|
|
||||||
|
Loads one block at a time, computes attention, and merges results.
|
||||||
|
Uses the same load_to_slot_layer / wait_slot_layer / get_kv_for_slot
|
||||||
|
methods as prefill for proven correctness.
|
||||||
|
"""
|
||||||
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
|
num_blocks = len(cpu_block_table)
|
||||||
|
if num_blocks == 0:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
if not load_slots:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
o_acc, lse_acc = None, None
|
||||||
|
num_slots = len(load_slots)
|
||||||
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
|
# Phase 1: Pre-load up to num_slots blocks
|
||||||
|
num_preload = min(num_slots, num_blocks)
|
||||||
|
for i in range(num_preload):
|
||||||
|
offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
|
||||||
|
|
||||||
|
# Phase 2: Process blocks with pipeline
|
||||||
|
for block_idx in range(num_blocks):
|
||||||
|
current_slot = load_slots[block_idx % num_slots]
|
||||||
|
cpu_block_id = cpu_block_table[block_idx]
|
||||||
|
|
||||||
|
# Wait for current slot's transfer to complete
|
||||||
|
offload_engine.wait_slot_layer(current_slot)
|
||||||
|
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
# Get KV from slot
|
||||||
|
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
|
||||||
|
|
||||||
|
# Handle partial last block
|
||||||
|
is_last_block = (block_idx == num_blocks - 1)
|
||||||
|
if is_last_block and last_block_valid_tokens < block_size:
|
||||||
|
prev_k = prev_k[:, :last_block_valid_tokens, :, :]
|
||||||
|
prev_v = prev_v[:, :last_block_valid_tokens, :, :]
|
||||||
|
|
||||||
|
# Compute attention
|
||||||
|
prev_o, prev_lse = flash_attn_with_lse(
|
||||||
|
q_batched, prev_k, prev_v,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record compute done for slot reuse
|
||||||
|
offload_engine.record_slot_compute_done(current_slot)
|
||||||
|
|
||||||
|
# Start loading next block (pipeline)
|
||||||
|
next_block_idx = block_idx + num_slots
|
||||||
|
if next_block_idx < num_blocks:
|
||||||
|
offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
|
||||||
|
|
||||||
|
# Merge with accumulated
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
if o_acc is None:
|
||||||
|
o_acc, lse_acc = prev_o, prev_lse
|
||||||
|
else:
|
||||||
|
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
|
||||||
|
|
||||||
|
return o_acc, lse_acc
|
||||||
|
|
||||||
|
def _decode_with_layer_pipeline(
|
||||||
|
self,
|
||||||
|
q_batched: torch.Tensor,
|
||||||
|
cpu_block_table: list,
|
||||||
|
offload_engine,
|
||||||
|
block_size: int,
|
||||||
|
last_block_valid_tokens: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Decode using cross-layer pipeline for optimized H2D transfer.
|
||||||
|
|
||||||
|
This method uses pre-loaded layer buffers instead of loading
|
||||||
|
blocks one by one. The pipeline loads the next layer's data
|
||||||
|
while the current layer computes, achieving transfer/compute overlap.
|
||||||
|
|
||||||
|
The key insight is that each layer needs the SAME blocks but from
|
||||||
|
different layers of CPU cache. By double-buffering and pipelining
|
||||||
|
across layers, we reduce total latency.
|
||||||
|
"""
|
||||||
|
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
|
||||||
|
|
||||||
|
num_blocks = len(cpu_block_table)
|
||||||
|
if num_blocks == 0:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
compute_stream = offload_engine.compute_stream
|
||||||
|
|
||||||
|
# Get KV from pre-loaded layer buffer (triggers next layer loading)
|
||||||
|
prev_k, prev_v = offload_engine.get_decode_layer_kv(self.layer_id, num_blocks)
|
||||||
|
|
||||||
|
# prev_k, prev_v shape: [num_blocks, block_size, kv_heads, head_dim]
|
||||||
|
# Reshape to [1, num_blocks * block_size, kv_heads, head_dim]
|
||||||
|
total_tokens = num_blocks * block_size
|
||||||
|
|
||||||
|
# Handle partial last block
|
||||||
|
if last_block_valid_tokens < block_size:
|
||||||
|
# Only use valid tokens from last block
|
||||||
|
actual_tokens = (num_blocks - 1) * block_size + last_block_valid_tokens
|
||||||
|
# Flatten and truncate
|
||||||
|
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])[:actual_tokens]
|
||||||
|
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])[:actual_tokens]
|
||||||
|
else:
|
||||||
|
prev_k_flat = prev_k.reshape(-1, prev_k.shape[-2], prev_k.shape[-1])
|
||||||
|
prev_v_flat = prev_v.reshape(-1, prev_v.shape[-2], prev_v.shape[-1])
|
||||||
|
|
||||||
|
# Add batch dimension: [1, total_tokens, kv_heads, head_dim]
|
||||||
|
prev_k_batched = prev_k_flat.unsqueeze(0)
|
||||||
|
prev_v_batched = prev_v_flat.unsqueeze(0)
|
||||||
|
|
||||||
|
# Compute attention on all prefilled blocks at once
|
||||||
|
with torch.cuda.stream(compute_stream):
|
||||||
|
o_acc, lse_acc = flash_attn_with_lse(
|
||||||
|
q_batched, prev_k_batched, prev_v_batched,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
causal=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return o_acc, lse_acc
|
||||||
|
|||||||
@@ -27,13 +27,13 @@ class RMSNorm(nn.Module):
|
|||||||
x = x.to(orig_dtype).mul_(self.weight)
|
x = x.to(orig_dtype).mul_(self.weight)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@torch.compile
|
||||||
def add_rms_forward(
|
def add_rms_forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
|
# Input MUST be 2D [N, D] to avoid recompilation due to rank mismatch
|
||||||
# Note: @torch.compile removed due to OOM with 64k sequences (memory fragmentation)
|
|
||||||
orig_dtype = x.dtype
|
orig_dtype = x.dtype
|
||||||
x = x.float().add_(residual.float())
|
x = x.float().add_(residual.float())
|
||||||
residual = x.to(orig_dtype)
|
residual = x.to(orig_dtype)
|
||||||
|
|||||||
@@ -3,13 +3,7 @@
|
|||||||
from nanovllm.models.registry import register_model, get_model_class, MODEL_REGISTRY
|
from nanovllm.models.registry import register_model, get_model_class, MODEL_REGISTRY
|
||||||
|
|
||||||
# Import models to trigger registration
|
# Import models to trigger registration
|
||||||
# Qwen3 requires transformers>=4.51.0 for Qwen3Config
|
from nanovllm.models import qwen3
|
||||||
try:
|
|
||||||
from nanovllm.models import qwen3
|
|
||||||
except ImportError as e:
|
|
||||||
import warnings
|
|
||||||
warnings.warn(f"Qwen3 model not available (requires transformers>=4.51.0): {e}")
|
|
||||||
|
|
||||||
from nanovllm.models import llama
|
from nanovllm.models import llama
|
||||||
|
|
||||||
__all__ = ["register_model", "get_model_class", "MODEL_REGISTRY"]
|
__all__ = ["register_model", "get_model_class", "MODEL_REGISTRY"]
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Optional, List, Tuple, Any
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
@@ -14,9 +14,26 @@ class Context:
|
|||||||
context_lens: torch.Tensor | None = None
|
context_lens: torch.Tensor | None = None
|
||||||
block_tables: torch.Tensor | None = None
|
block_tables: torch.Tensor | None = None
|
||||||
|
|
||||||
# Sparse prefill attention support (GPU-only path)
|
# Chunked prefill support
|
||||||
# When set, uses policy.sparse_prefill_attention() instead of FlashAttention
|
is_chunked_prefill: bool = False
|
||||||
sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True
|
# Previous KV chunks info: List of (start_pos, end_pos) for blocks on CPU
|
||||||
|
prev_kv_ranges: List[Tuple[int, int]] = field(default_factory=list)
|
||||||
|
# Current chunk's position offset (for causal mask)
|
||||||
|
chunk_offset: int = 0
|
||||||
|
# Reference to kvcache manager for loading previous KV (HybridKVCacheManager)
|
||||||
|
kvcache_manager: Any = None
|
||||||
|
# Current layer's previous K/V chunks (loaded from CPU)
|
||||||
|
# Set by model_runner before each layer's forward
|
||||||
|
prev_kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]] = field(default_factory=list)
|
||||||
|
# Current sequence being processed (for chunked prefill to load KV)
|
||||||
|
chunked_seq: Any = None
|
||||||
|
# Position within block for decode (used for reading from Decode region)
|
||||||
|
decode_pos_in_block: int = 0
|
||||||
|
# Starting position within block where decode tokens began (for accumulated token tracking)
|
||||||
|
# Used when batching decode offloads - we need to attend to all accumulated tokens
|
||||||
|
decode_start_pos_in_block: int = 0
|
||||||
|
# Current chunk index for ring buffer pipeline (prefill only)
|
||||||
|
current_chunk_idx: int = 0
|
||||||
|
|
||||||
|
|
||||||
_CONTEXT = Context()
|
_CONTEXT = Context()
|
||||||
@@ -35,7 +52,14 @@ def set_context(
|
|||||||
slot_mapping=None,
|
slot_mapping=None,
|
||||||
context_lens=None,
|
context_lens=None,
|
||||||
block_tables=None,
|
block_tables=None,
|
||||||
sparse_prefill_policy=None,
|
is_chunked_prefill=False,
|
||||||
|
prev_kv_ranges=None,
|
||||||
|
chunk_offset=0,
|
||||||
|
kvcache_manager=None,
|
||||||
|
chunked_seq=None,
|
||||||
|
decode_pos_in_block=0,
|
||||||
|
decode_start_pos_in_block=0,
|
||||||
|
current_chunk_idx=0,
|
||||||
):
|
):
|
||||||
global _CONTEXT
|
global _CONTEXT
|
||||||
_CONTEXT = Context(
|
_CONTEXT = Context(
|
||||||
@@ -47,7 +71,14 @@ def set_context(
|
|||||||
slot_mapping=slot_mapping,
|
slot_mapping=slot_mapping,
|
||||||
context_lens=context_lens,
|
context_lens=context_lens,
|
||||||
block_tables=block_tables,
|
block_tables=block_tables,
|
||||||
sparse_prefill_policy=sparse_prefill_policy,
|
is_chunked_prefill=is_chunked_prefill,
|
||||||
|
prev_kv_ranges=prev_kv_ranges or [],
|
||||||
|
chunk_offset=chunk_offset,
|
||||||
|
kvcache_manager=kvcache_manager,
|
||||||
|
chunked_seq=chunked_seq,
|
||||||
|
decode_pos_in_block=decode_pos_in_block,
|
||||||
|
decode_start_pos_in_block=decode_start_pos_in_block,
|
||||||
|
current_chunk_idx=current_chunk_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
76
progress.md
Normal file
76
progress.md
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
# Progress Log: Multi-Model Support
|
||||||
|
|
||||||
|
## Session: 2026-01-10
|
||||||
|
|
||||||
|
### Initial Analysis Complete
|
||||||
|
|
||||||
|
**Time**: Session start
|
||||||
|
|
||||||
|
**Actions:**
|
||||||
|
1. Read `nanovllm/engine/model_runner.py` - 确认硬编码位置 (line 35)
|
||||||
|
2. Read `nanovllm/models/qwen3.py` - 理解 Qwen3 模型结构
|
||||||
|
3. Read `nanovllm/utils/loader.py` - 理解权重加载机制
|
||||||
|
4. Read `nanovllm/layers/rotary_embedding.py` - 发现 RoPE scaling 限制
|
||||||
|
5. Read `/home/zijie/models/Llama-3.1-8B-Instruct/config.json` - 理解 Llama 配置
|
||||||
|
|
||||||
|
**Key Findings:**
|
||||||
|
- 模型加载在 `model_runner.py:35` 硬编码为 Qwen3
|
||||||
|
- RoPE 目前不支持 scaling (`assert rope_scaling is None`)
|
||||||
|
- Llama 3.1 需要 "llama3" 类型的 RoPE scaling
|
||||||
|
- Llama 无 q_norm/k_norm,无 attention bias
|
||||||
|
|
||||||
|
**Created:**
|
||||||
|
- `task_plan.md` - 6 阶段实施计划
|
||||||
|
- `findings.md` - 技术分析和发现
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Phase Status
|
||||||
|
|
||||||
|
| Phase | Status | Notes |
|
||||||
|
|-------|--------|-------|
|
||||||
|
| 1. Model Registry | **COMPLETED** | `registry.py`, `__init__.py` |
|
||||||
|
| 2. Llama3 RoPE | **COMPLETED** | `rotary_embedding.py` |
|
||||||
|
| 3. Llama Model | **COMPLETED** | `llama.py` |
|
||||||
|
| 4. ModelRunner | **COMPLETED** | Dynamic loading |
|
||||||
|
| 5. Qwen3 Register | **COMPLETED** | `@register_model` decorator |
|
||||||
|
| 6. Testing | **COMPLETED** | Both Llama & Qwen3 pass |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Test Results
|
||||||
|
|
||||||
|
### Llama 3.1-8B-Instruct (32K needle, GPU 0, offload)
|
||||||
|
```
|
||||||
|
Input: 32768 tokens
|
||||||
|
Expected: 7492
|
||||||
|
Output: 7492
|
||||||
|
Status: PASSED
|
||||||
|
Prefill: 1644 tok/s
|
||||||
|
```
|
||||||
|
|
||||||
|
### Qwen3-4B (8K needle, GPU 1, offload) - Regression Test
|
||||||
|
```
|
||||||
|
Input: 8192 tokens
|
||||||
|
Expected: 7492
|
||||||
|
Output: 7492
|
||||||
|
Status: PASSED
|
||||||
|
Prefill: 3295 tok/s
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Files Modified This Session
|
||||||
|
|
||||||
|
| File | Action | Description |
|
||||||
|
|------|--------|-------------|
|
||||||
|
| `nanovllm/models/registry.py` | created | Model registry with `@register_model` decorator |
|
||||||
|
| `nanovllm/models/__init__.py` | created | Export registry functions, import models |
|
||||||
|
| `nanovllm/models/llama.py` | created | Llama model implementation |
|
||||||
|
| `nanovllm/models/qwen3.py` | modified | Added `@register_model` decorator |
|
||||||
|
| `nanovllm/layers/rotary_embedding.py` | modified | Added Llama3 RoPE scaling |
|
||||||
|
| `nanovllm/engine/model_runner.py` | modified | Dynamic model loading via registry |
|
||||||
|
| `.claude/rules/gpu-testing.md` | created | GPU testing rules |
|
||||||
|
| `task_plan.md` | created | Implementation plan |
|
||||||
|
| `findings.md` | created | Technical findings |
|
||||||
|
| `progress.md` | created | Progress tracking |
|
||||||
144
task_plan.md
Normal file
144
task_plan.md
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
# Task Plan: Multi-Model Support for nanovllm
|
||||||
|
|
||||||
|
## Goal
|
||||||
|
扩展 nanovllm 框架以支持多种模型(当前只支持 Qwen3),特别是添加 Llama-3.1-8B-Instruct 支持,并建立可扩展的模型添加范式。
|
||||||
|
|
||||||
|
## Current State Analysis
|
||||||
|
|
||||||
|
### 硬编码问题位置
|
||||||
|
- `nanovllm/engine/model_runner.py:35`: 直接实例化 `Qwen3ForCausalLM(hf_config)`
|
||||||
|
- `nanovllm/engine/model_runner.py:9`: 硬编码导入 `from nanovllm.models.qwen3 import Qwen3ForCausalLM`
|
||||||
|
|
||||||
|
### Qwen3 vs Llama 3.1 架构差异
|
||||||
|
|
||||||
|
| Feature | Qwen3 | Llama 3.1 |
|
||||||
|
|---------|-------|-----------|
|
||||||
|
| Config Class | Qwen3Config | LlamaConfig |
|
||||||
|
| attention_bias | True (可配置) | False |
|
||||||
|
| q_norm/k_norm | 有 (when bias=False) | 无 |
|
||||||
|
| mlp_bias | N/A | False |
|
||||||
|
| RoPE Scaling | None (目前) | llama3 类型 |
|
||||||
|
| RoPE theta | 1000000 | 500000 |
|
||||||
|
| hidden_act | silu | silu |
|
||||||
|
| tie_word_embeddings | True | False |
|
||||||
|
|
||||||
|
### 关键限制
|
||||||
|
- `rotary_embedding.py:59`: `assert rope_scaling is None` - 不支持 RoPE scaling
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Phases
|
||||||
|
|
||||||
|
### Phase 1: Create Model Registry Pattern [pending]
|
||||||
|
**Files to modify:**
|
||||||
|
- `nanovllm/models/__init__.py` (new)
|
||||||
|
- `nanovllm/models/registry.py` (new)
|
||||||
|
|
||||||
|
**Tasks:**
|
||||||
|
1. 创建模型注册表机制
|
||||||
|
2. 定义模型注册装饰器 `@register_model`
|
||||||
|
3. 实现 `get_model_class(hf_config)` 函数,根据 `architectures` 字段自动选择模型
|
||||||
|
|
||||||
|
**Design:**
|
||||||
|
```python
|
||||||
|
MODEL_REGISTRY: dict[str, type] = {}
|
||||||
|
|
||||||
|
def register_model(*architectures):
|
||||||
|
"""Decorator to register a model class for given architecture names."""
|
||||||
|
def decorator(cls):
|
||||||
|
for arch in architectures:
|
||||||
|
MODEL_REGISTRY[arch] = cls
|
||||||
|
return cls
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def get_model_class(hf_config) -> type:
|
||||||
|
"""Get model class based on HF config architectures."""
|
||||||
|
for arch in hf_config.architectures:
|
||||||
|
if arch in MODEL_REGISTRY:
|
||||||
|
return MODEL_REGISTRY[arch]
|
||||||
|
raise ValueError(f"Unsupported architecture: {hf_config.architectures}")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 2: Add Llama3 RoPE Scaling Support [pending]
|
||||||
|
**Files to modify:**
|
||||||
|
- `nanovllm/layers/rotary_embedding.py`
|
||||||
|
|
||||||
|
**Tasks:**
|
||||||
|
1. 实现 `Llama3RotaryEmbedding` 类,支持 llama3 rope_type
|
||||||
|
2. 修改 `get_rope()` 函数,根据 rope_scaling 类型选择实现
|
||||||
|
3. 保持向后兼容(rope_scaling=None 使用原实现)
|
||||||
|
|
||||||
|
**Llama3 RoPE Scaling Formula:**
|
||||||
|
```python
|
||||||
|
# From transformers:
|
||||||
|
# low_freq_factor, high_freq_factor, original_max_position_embeddings
|
||||||
|
# Adjust frequencies based on wavelength thresholds
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 3: Implement Llama Model [pending]
|
||||||
|
**Files to create:**
|
||||||
|
- `nanovllm/models/llama.py`
|
||||||
|
|
||||||
|
**Tasks:**
|
||||||
|
1. 创建 `LlamaAttention` 类(无 q_norm/k_norm,无 QKV bias)
|
||||||
|
2. 创建 `LlamaMLP` 类(与 Qwen3MLP 类似,无 bias)
|
||||||
|
3. 创建 `LlamaDecoderLayer` 类
|
||||||
|
4. 创建 `LlamaModel` 和 `LlamaForCausalLM` 类
|
||||||
|
5. 添加 `packed_modules_mapping` 以支持权重加载
|
||||||
|
6. 使用 `@register_model("LlamaForCausalLM")` 注册
|
||||||
|
|
||||||
|
### Phase 4: Modify ModelRunner for Dynamic Loading [pending]
|
||||||
|
**Files to modify:**
|
||||||
|
- `nanovllm/engine/model_runner.py`
|
||||||
|
|
||||||
|
**Tasks:**
|
||||||
|
1. 移除硬编码 `from nanovllm.models.qwen3 import Qwen3ForCausalLM`
|
||||||
|
2. 导入 `from nanovllm.models import get_model_class`
|
||||||
|
3. 替换 `self.model = Qwen3ForCausalLM(hf_config)` 为:
|
||||||
|
```python
|
||||||
|
model_class = get_model_class(hf_config)
|
||||||
|
self.model = model_class(hf_config)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Phase 5: Register Qwen3 Model [pending]
|
||||||
|
**Files to modify:**
|
||||||
|
- `nanovllm/models/qwen3.py`
|
||||||
|
|
||||||
|
**Tasks:**
|
||||||
|
1. 导入 `from nanovllm.models.registry import register_model`
|
||||||
|
2. 添加 `@register_model("Qwen3ForCausalLM", "Qwen2ForCausalLM")` 装饰器
|
||||||
|
|
||||||
|
### Phase 6: Test with Llama-3.1-8B-Instruct [pending]
|
||||||
|
**Files:**
|
||||||
|
- `tests/test_needle.py` (existing, use for validation)
|
||||||
|
|
||||||
|
**Tasks:**
|
||||||
|
1. 运行 needle 测试: `python tests/test_needle.py --model ~/models/Llama-3.1-8B-Instruct`
|
||||||
|
2. 验证模型加载正确
|
||||||
|
3. 验证推理输出正确
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Errors Encountered
|
||||||
|
| Error | Attempt | Resolution |
|
||||||
|
|-------|---------|------------|
|
||||||
|
| (none yet) | | |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Success Criteria
|
||||||
|
- [x] 分析完成:理解当前架构和需要的改动
|
||||||
|
- [ ] Phase 1: 模型注册表实现
|
||||||
|
- [ ] Phase 2: Llama3 RoPE scaling 支持
|
||||||
|
- [ ] Phase 3: Llama 模型实现
|
||||||
|
- [ ] Phase 4: ModelRunner 动态加载
|
||||||
|
- [ ] Phase 5: Qwen3 模型注册
|
||||||
|
- [ ] Phase 6: Llama needle 测试通过
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
- 保持现有 Qwen3 功能不变
|
||||||
|
- 遵循现有代码风格
|
||||||
|
- 复用现有 layers 组件(Linear, RMSNorm, Embedding 等)
|
||||||
|
- 只添加必要的代码,不过度工程化
|
||||||
@@ -1,112 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
# Run NIAH tests in parallel on 6 GPUs
|
|
||||||
# This tests the dynamic port allocation fix
|
|
||||||
|
|
||||||
set -e
|
|
||||||
|
|
||||||
MODEL="${1:-/home/zijie/models/Llama-3.1-8B-Instruct}"
|
|
||||||
PROJECT_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
|
|
||||||
|
|
||||||
echo "=========================================="
|
|
||||||
echo "Parallel NIAH Test on 6 GPUs"
|
|
||||||
echo "=========================================="
|
|
||||||
echo "Model: $MODEL"
|
|
||||||
echo "Project: $PROJECT_ROOT"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Sample distribution (100 samples total):
|
|
||||||
# GPU 0: 0-16 (17 samples)
|
|
||||||
# GPU 1: 17-33 (17 samples)
|
|
||||||
# GPU 2: 34-50 (17 samples)
|
|
||||||
# GPU 3: 51-67 (17 samples)
|
|
||||||
# GPU 4: 68-83 (16 samples)
|
|
||||||
# GPU 5: 84-99 (16 samples)
|
|
||||||
|
|
||||||
declare -a RANGES=("0-16" "17-33" "34-50" "51-67" "68-83" "84-99")
|
|
||||||
declare -a PIDS=()
|
|
||||||
|
|
||||||
# Create log directory
|
|
||||||
LOG_DIR="$PROJECT_ROOT/logs"
|
|
||||||
mkdir -p "$LOG_DIR"
|
|
||||||
|
|
||||||
# Start all 6 processes
|
|
||||||
for gpu in {0..5}; do
|
|
||||||
range="${RANGES[$gpu]}"
|
|
||||||
log_file="$LOG_DIR/gpu${gpu}_${range}.log"
|
|
||||||
|
|
||||||
echo "Starting GPU $gpu: samples $range -> $log_file"
|
|
||||||
|
|
||||||
CUDA_VISIBLE_DEVICES=$gpu PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH" \
|
|
||||||
python "$PROJECT_ROOT/tests/test_ruler_niah.py" \
|
|
||||||
--model "$MODEL" \
|
|
||||||
--sample-indices "$range" \
|
|
||||||
--enable-offload \
|
|
||||||
--num-gpu-blocks 4 \
|
|
||||||
--quiet \
|
|
||||||
> "$log_file" 2>&1 &
|
|
||||||
|
|
||||||
PIDS+=($!)
|
|
||||||
|
|
||||||
# Small delay to stagger starts
|
|
||||||
sleep 2
|
|
||||||
done
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "All 6 processes started. Waiting for completion..."
|
|
||||||
echo "PIDs: ${PIDS[*]}"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Wait for all processes and collect results
|
|
||||||
declare -a RESULTS=()
|
|
||||||
ALL_PASSED=true
|
|
||||||
|
|
||||||
for i in {0..5}; do
|
|
||||||
pid="${PIDS[$i]}"
|
|
||||||
range="${RANGES[$i]}"
|
|
||||||
log_file="$LOG_DIR/gpu${i}_${range}.log"
|
|
||||||
|
|
||||||
if wait $pid; then
|
|
||||||
RESULTS+=("GPU $i ($range): PASSED")
|
|
||||||
echo "GPU $i completed successfully"
|
|
||||||
else
|
|
||||||
RESULTS+=("GPU $i ($range): FAILED (exit code $?)")
|
|
||||||
ALL_PASSED=false
|
|
||||||
echo "GPU $i FAILED!"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "=========================================="
|
|
||||||
echo "RESULTS SUMMARY"
|
|
||||||
echo "=========================================="
|
|
||||||
for result in "${RESULTS[@]}"; do
|
|
||||||
echo "$result"
|
|
||||||
done
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Show accuracy from each log
|
|
||||||
echo "Accuracy per GPU:"
|
|
||||||
for i in {0..5}; do
|
|
||||||
range="${RANGES[$i]}"
|
|
||||||
log_file="$LOG_DIR/gpu${i}_${range}.log"
|
|
||||||
if [ -f "$log_file" ]; then
|
|
||||||
accuracy=$(grep -E "Accuracy:|accuracy" "$log_file" | tail -1 || echo "N/A")
|
|
||||||
port=$(grep "Auto-assigned distributed port" "$log_file" | head -1 || echo "N/A")
|
|
||||||
echo " GPU $i ($range): $accuracy | $port"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
if $ALL_PASSED; then
|
|
||||||
echo "=========================================="
|
|
||||||
echo "ALL 6 TESTS PASSED!"
|
|
||||||
echo "Dynamic port allocation works correctly."
|
|
||||||
echo "=========================================="
|
|
||||||
exit 0
|
|
||||||
else
|
|
||||||
echo "=========================================="
|
|
||||||
echo "SOME TESTS FAILED!"
|
|
||||||
echo "Check logs in $LOG_DIR"
|
|
||||||
echo "=========================================="
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
@@ -1,163 +0,0 @@
|
|||||||
"""
|
|
||||||
Needle-in-haystack test with MInference sparse attention.
|
|
||||||
|
|
||||||
Tests: MInference sparse prefill on GPU-only path (no CPU offload).
|
|
||||||
This validates that MInference's vertical + slash sparse pattern can
|
|
||||||
correctly retrieve information from long context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
from nanovllm import LLM, SamplingParams
|
|
||||||
from nanovllm.config import SparsePolicyType
|
|
||||||
from utils import generate_needle_prompt, check_needle_answer
|
|
||||||
|
|
||||||
|
|
||||||
def run_minference_test(
|
|
||||||
model_path: str,
|
|
||||||
max_model_len: int = 16384,
|
|
||||||
input_len: int = 8192,
|
|
||||||
needle_position: float = 0.5,
|
|
||||||
needle_value: str = "7492",
|
|
||||||
adaptive_budget: float = 0.3,
|
|
||||||
max_new_tokens: int = 32,
|
|
||||||
verbose: bool = True,
|
|
||||||
) -> bool:
|
|
||||||
"""
|
|
||||||
Run needle test with MInference sparse prefill attention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path: Path to model
|
|
||||||
max_model_len: Maximum model context length
|
|
||||||
input_len: Target input sequence length
|
|
||||||
needle_position: Where to place needle (0.0-1.0)
|
|
||||||
needle_value: The secret value to find
|
|
||||||
adaptive_budget: MInference budget as fraction of seq_len
|
|
||||||
max_new_tokens: Maximum tokens to generate
|
|
||||||
verbose: Print detailed output
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if test passed, False otherwise
|
|
||||||
"""
|
|
||||||
if verbose:
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"MInference Sparse Prefill Test (GPU-only)")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f"Model: {model_path}")
|
|
||||||
print(f"Max model len: {max_model_len}")
|
|
||||||
print(f"Input length: {input_len}")
|
|
||||||
print(f"Needle position: {needle_position:.0%}")
|
|
||||||
print(f"Needle value: {needle_value}")
|
|
||||||
print(f"Adaptive budget: {adaptive_budget}")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
# Initialize LLM with MInference sparse attention
|
|
||||||
llm = LLM(
|
|
||||||
model_path,
|
|
||||||
enforce_eager=True,
|
|
||||||
max_model_len=max_model_len,
|
|
||||||
max_num_batched_tokens=max_model_len,
|
|
||||||
enable_cpu_offload=False, # GPU-only
|
|
||||||
sparse_policy=SparsePolicyType.MINFERENCE,
|
|
||||||
minference_adaptive_budget=adaptive_budget,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate needle prompt
|
|
||||||
prompt, expected = generate_needle_prompt(
|
|
||||||
tokenizer=llm.tokenizer,
|
|
||||||
target_length=input_len,
|
|
||||||
needle_position=needle_position,
|
|
||||||
needle_value=needle_value,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Generate output
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
temperature=0.6,
|
|
||||||
max_tokens=max_new_tokens,
|
|
||||||
)
|
|
||||||
outputs = llm.generate([prompt], sampling_params, use_tqdm=True)
|
|
||||||
|
|
||||||
# Check result
|
|
||||||
output_text = outputs[0]["text"]
|
|
||||||
output_token_ids = outputs[0]["token_ids"]
|
|
||||||
passed = check_needle_answer(output_text, expected)
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"Result")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f"Expected: {expected}")
|
|
||||||
print(f"Output tokens ({len(output_token_ids)}): {output_token_ids[:20]}")
|
|
||||||
print(f"Output: {output_text[:200]}...")
|
|
||||||
print(f"Status: {'PASSED' if passed else 'FAILED'}")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
return passed
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Needle-in-haystack test with MInference sparse prefill"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--model", "-m",
|
|
||||||
type=str,
|
|
||||||
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
|
|
||||||
help="Path to model"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-model-len",
|
|
||||||
type=int,
|
|
||||||
default=16 * 1024,
|
|
||||||
help="Maximum model context length"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--input-len",
|
|
||||||
type=int,
|
|
||||||
default=8 * 1024,
|
|
||||||
help="Target input sequence length"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--needle-position",
|
|
||||||
type=float,
|
|
||||||
default=0.5,
|
|
||||||
help="Needle position (0.0=start, 0.5=middle, 1.0=end)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--needle-value",
|
|
||||||
type=str,
|
|
||||||
default="7492",
|
|
||||||
help="The secret value to hide"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--adaptive-budget",
|
|
||||||
type=float,
|
|
||||||
default=0.3,
|
|
||||||
help="MInference adaptive budget (fraction of seq_len)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-new-tokens",
|
|
||||||
type=int,
|
|
||||||
default=32,
|
|
||||||
help="Maximum tokens to generate"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
passed = run_minference_test(
|
|
||||||
model_path=args.model,
|
|
||||||
max_model_len=args.max_model_len,
|
|
||||||
input_len=args.input_len,
|
|
||||||
needle_position=args.needle_position,
|
|
||||||
needle_value=args.needle_value,
|
|
||||||
adaptive_budget=args.adaptive_budget,
|
|
||||||
max_new_tokens=args.max_new_tokens,
|
|
||||||
verbose=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if passed:
|
|
||||||
print("test_minference_gpu: PASSED")
|
|
||||||
else:
|
|
||||||
print("test_minference_gpu: FAILED")
|
|
||||||
exit(1)
|
|
||||||
@@ -31,14 +31,8 @@ def run_needle_test(
|
|||||||
max_new_tokens: int = 32,
|
max_new_tokens: int = 32,
|
||||||
enable_cpu_offload: bool = False,
|
enable_cpu_offload: bool = False,
|
||||||
enable_quest: bool = False,
|
enable_quest: bool = False,
|
||||||
enable_minference: bool = False,
|
|
||||||
sparse_topk: int = 8,
|
sparse_topk: int = 8,
|
||||||
sparse_threshold: int = 4,
|
sparse_threshold: int = 4,
|
||||||
minference_budget: float = 0.3,
|
|
||||||
minference_vertical: int = 1000,
|
|
||||||
minference_slash: int = 6096,
|
|
||||||
gpu_utilization: float = 0.9,
|
|
||||||
enforce_eager: bool = True,
|
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
@@ -55,25 +49,14 @@ def run_needle_test(
|
|||||||
max_new_tokens: Maximum tokens to generate
|
max_new_tokens: Maximum tokens to generate
|
||||||
enable_cpu_offload: Enable CPU offload mode
|
enable_cpu_offload: Enable CPU offload mode
|
||||||
enable_quest: Enable Quest sparse attention (decode-only Top-K)
|
enable_quest: Enable Quest sparse attention (decode-only Top-K)
|
||||||
enable_minference: Enable MInference sparse prefill (GPU-only)
|
|
||||||
sparse_topk: Top-K blocks for Quest
|
sparse_topk: Top-K blocks for Quest
|
||||||
sparse_threshold: Apply sparse only when blocks > threshold
|
sparse_threshold: Apply sparse only when blocks > threshold
|
||||||
minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode)
|
|
||||||
minference_vertical: Fixed vertical_size (only used when budget=None)
|
|
||||||
minference_slash: Fixed slash_size (only used when budget=None)
|
|
||||||
gpu_utilization: GPU memory utilization fraction
|
|
||||||
verbose: Print detailed output
|
verbose: Print detailed output
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if test passed, False otherwise
|
True if test passed, False otherwise
|
||||||
"""
|
"""
|
||||||
# Determine sparse policy
|
sparse_policy = SparsePolicyType.QUEST if enable_quest else SparsePolicyType.FULL
|
||||||
if enable_minference:
|
|
||||||
sparse_policy = SparsePolicyType.MINFERENCE
|
|
||||||
elif enable_quest:
|
|
||||||
sparse_policy = SparsePolicyType.QUEST
|
|
||||||
else:
|
|
||||||
sparse_policy = SparsePolicyType.FULL
|
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
print(f"\n{'='*60}")
|
print(f"\n{'='*60}")
|
||||||
@@ -86,40 +69,24 @@ def run_needle_test(
|
|||||||
print(f"Needle position: {needle_position:.0%}")
|
print(f"Needle position: {needle_position:.0%}")
|
||||||
print(f"Needle value: {needle_value}")
|
print(f"Needle value: {needle_value}")
|
||||||
print(f"CPU offload: {enable_cpu_offload}")
|
print(f"CPU offload: {enable_cpu_offload}")
|
||||||
print(f"Sparse policy: {sparse_policy.name}")
|
if enable_cpu_offload:
|
||||||
if enable_cpu_offload and enable_quest:
|
print(f"Sparse policy: {sparse_policy.name} (topk={sparse_topk}, threshold={sparse_threshold})")
|
||||||
print(f" Quest: topk={sparse_topk}, threshold={sparse_threshold}")
|
|
||||||
if enable_minference:
|
|
||||||
if minference_budget is not None:
|
|
||||||
print(f" MInference: adaptive (budget={minference_budget})")
|
|
||||||
else:
|
|
||||||
print(f" MInference: fixed (vertical={minference_vertical}, slash={minference_slash})")
|
|
||||||
print(f"{'='*60}\n")
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
# 1. Initialize LLM
|
# 1. Initialize LLM
|
||||||
llm_kwargs = {
|
llm_kwargs = {
|
||||||
"enforce_eager": enforce_eager,
|
"enforce_eager": True,
|
||||||
"max_model_len": max_model_len,
|
"max_model_len": max_model_len,
|
||||||
"max_num_batched_tokens": max_model_len,
|
"max_num_batched_tokens": max_model_len,
|
||||||
"enable_cpu_offload": enable_cpu_offload,
|
"enable_cpu_offload": enable_cpu_offload,
|
||||||
"kvcache_block_size": block_size,
|
"kvcache_block_size": block_size,
|
||||||
"gpu_memory_utilization": gpu_utilization,
|
|
||||||
}
|
}
|
||||||
if enable_cpu_offload:
|
if enable_cpu_offload:
|
||||||
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
||||||
|
llm_kwargs["sparse_policy"] = sparse_policy
|
||||||
llm_kwargs["sparse_topk_blocks"] = sparse_topk
|
llm_kwargs["sparse_topk_blocks"] = sparse_topk
|
||||||
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
|
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
|
||||||
|
|
||||||
# Set sparse policy (can be used with or without offload)
|
|
||||||
if enable_minference or enable_quest:
|
|
||||||
llm_kwargs["sparse_policy"] = sparse_policy
|
|
||||||
|
|
||||||
# MInference params (works with both GPU-only and offload mode)
|
|
||||||
if enable_minference:
|
|
||||||
llm_kwargs["minference_adaptive_budget"] = minference_budget
|
|
||||||
llm_kwargs["minference_vertical_size"] = minference_vertical
|
|
||||||
llm_kwargs["minference_slash_size"] = minference_slash
|
|
||||||
|
|
||||||
llm = LLM(model_path, **llm_kwargs)
|
llm = LLM(model_path, **llm_kwargs)
|
||||||
|
|
||||||
# 2. Generate needle prompt
|
# 2. Generate needle prompt
|
||||||
@@ -219,11 +186,6 @@ if __name__ == "__main__":
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable Quest sparse attention (decode-only Top-K selection)"
|
help="Enable Quest sparse attention (decode-only Top-K selection)"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--enable-minference",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sparse-topk",
|
"--sparse-topk",
|
||||||
type=int,
|
type=int,
|
||||||
@@ -236,49 +198,8 @@ if __name__ == "__main__":
|
|||||||
default=4,
|
default=4,
|
||||||
help="Apply sparse only when blocks > threshold"
|
help="Apply sparse only when blocks > threshold"
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
"--minference-budget",
|
|
||||||
type=float,
|
|
||||||
default=0.3,
|
|
||||||
help="MInference adaptive budget (fraction of seq_len, 0.3=30%% compute, 0=fixed mode)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--minference-vertical",
|
|
||||||
type=int,
|
|
||||||
default=1000,
|
|
||||||
help="Fixed vertical_size (only used when budget=0)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--minference-slash",
|
|
||||||
type=int,
|
|
||||||
default=6096,
|
|
||||||
help="Fixed slash_size (only used when budget=0)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--gpu-utilization",
|
|
||||||
type=float,
|
|
||||||
default=0.9,
|
|
||||||
help="GPU memory utilization (default: 0.9)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--enforce-eager",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Force eager execution (disable CUDA graphs)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-cuda-graph",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable CUDA graph (disable enforce_eager)"
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Convert budget=0 to None for fixed mode
|
|
||||||
minference_budget = args.minference_budget if args.minference_budget > 0 else None
|
|
||||||
|
|
||||||
# Determine enforce_eager: use_cuda_graph overrides enforce_eager
|
|
||||||
enforce_eager = not args.use_cuda_graph
|
|
||||||
|
|
||||||
passed = run_needle_test(
|
passed = run_needle_test(
|
||||||
model_path=args.model,
|
model_path=args.model,
|
||||||
max_model_len=args.max_model_len,
|
max_model_len=args.max_model_len,
|
||||||
@@ -290,14 +211,8 @@ if __name__ == "__main__":
|
|||||||
max_new_tokens=args.max_new_tokens,
|
max_new_tokens=args.max_new_tokens,
|
||||||
enable_cpu_offload=args.enable_offload,
|
enable_cpu_offload=args.enable_offload,
|
||||||
enable_quest=args.enable_quest,
|
enable_quest=args.enable_quest,
|
||||||
enable_minference=args.enable_minference,
|
|
||||||
sparse_topk=args.sparse_topk,
|
sparse_topk=args.sparse_topk,
|
||||||
sparse_threshold=args.sparse_threshold,
|
sparse_threshold=args.sparse_threshold,
|
||||||
minference_budget=minference_budget,
|
|
||||||
minference_vertical=args.minference_vertical,
|
|
||||||
minference_slash=args.minference_slash,
|
|
||||||
gpu_utilization=args.gpu_utilization,
|
|
||||||
enforce_eager=enforce_eager,
|
|
||||||
verbose=True,
|
verbose=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,841 +0,0 @@
|
|||||||
"""
|
|
||||||
OffloadedTensor 统一测试套件
|
|
||||||
|
|
||||||
本文件整合了 OffloadedTensor 的所有测试,包括:
|
|
||||||
1. 基础功能验证
|
|
||||||
2. Chunked GEMM 测试
|
|
||||||
3. 同步分析
|
|
||||||
|
|
||||||
核心组件:
|
|
||||||
- OffloadedTensor: 虚拟 GPU Tensor,支持透明 CPU/GPU 数据移动
|
|
||||||
- OffloadManager: LRU 缓存管理,支持同步/异步传输
|
|
||||||
- ChunkedOffloadLinear: 沿着 seqlen 维度分块的 Linear 层
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import weakref
|
|
||||||
import threading
|
|
||||||
import time
|
|
||||||
from typing import Optional, Dict, List, Tuple, Any
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Part 1: 核心组件
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
class OffloadedTensor(torch.Tensor):
|
|
||||||
"""
|
|
||||||
虚拟 GPU Tensor:假装在 GPU 上,实际可能在 CPU
|
|
||||||
|
|
||||||
所有计算操作通过 __torch_dispatch__ 拦截,
|
|
||||||
在计算前自动加载数据到 GPU。
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def __new__(cls, real_tensor: torch.Tensor, manager: 'OffloadManager', tensor_id: int):
|
|
||||||
device = torch.device("cuda", torch.cuda.current_device())
|
|
||||||
ret = torch.Tensor._make_wrapper_subclass(
|
|
||||||
cls,
|
|
||||||
real_tensor.size(),
|
|
||||||
strides=real_tensor.stride(),
|
|
||||||
dtype=real_tensor.dtype,
|
|
||||||
device=device,
|
|
||||||
requires_grad=real_tensor.requires_grad
|
|
||||||
)
|
|
||||||
ret._real_tensor = real_tensor
|
|
||||||
ret._manager = weakref.ref(manager)
|
|
||||||
ret._tensor_id = tensor_id
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def __init__(self, real_tensor: torch.Tensor, manager: 'OffloadManager', tensor_id: int):
|
|
||||||
self._real_tensor = real_tensor
|
|
||||||
self._manager = weakref.ref(manager)
|
|
||||||
self._tensor_id = tensor_id
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self) -> torch.device:
|
|
||||||
"""永远返回 CUDA device,欺骗 PyTorch 的检查"""
|
|
||||||
return torch.device("cuda", torch.cuda.current_device())
|
|
||||||
|
|
||||||
def to(self, *args, **kwargs):
|
|
||||||
"""拦截 .to() 调用"""
|
|
||||||
device = None
|
|
||||||
if args and isinstance(args[0], torch.device):
|
|
||||||
device = args[0]
|
|
||||||
elif 'device' in kwargs:
|
|
||||||
device = kwargs['device']
|
|
||||||
|
|
||||||
if device and device.type == "cuda":
|
|
||||||
return self
|
|
||||||
return super().to(*args, **kwargs)
|
|
||||||
|
|
||||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
||||||
"""拦截所有 PyTorch 操作,自动加载数据"""
|
|
||||||
kwargs = kwargs or {}
|
|
||||||
|
|
||||||
manager = self._manager()
|
|
||||||
if manager:
|
|
||||||
manager.stats['dispatch_count'] += 1
|
|
||||||
|
|
||||||
# 特殊处理:detach 返回 self
|
|
||||||
func_name = getattr(func, 'name', '')
|
|
||||||
if isinstance(func_name, str) and 'detach' in func_name.lower():
|
|
||||||
return self
|
|
||||||
|
|
||||||
# 解包 OffloadedTensor 为真实 tensor
|
|
||||||
def unwrap(t):
|
|
||||||
if isinstance(t, OffloadedTensor):
|
|
||||||
mgr = t._manager()
|
|
||||||
if mgr:
|
|
||||||
return mgr.get_gpu_tensor(t._real_tensor, t._tensor_id)
|
|
||||||
return t._real_tensor.cuda()
|
|
||||||
return t
|
|
||||||
|
|
||||||
new_args = torch.utils._pytree.tree_map(unwrap, args)
|
|
||||||
new_kwargs = torch.utils._pytree.tree_map(unwrap, kwargs)
|
|
||||||
|
|
||||||
result = func(*new_args, **new_kwargs)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class OffloadManager:
|
|
||||||
"""
|
|
||||||
管理 tensor 的卸载和预取
|
|
||||||
|
|
||||||
特性:
|
|
||||||
- LRU 缓存管理 GPU 上的张量
|
|
||||||
- 支持同步/异步传输模式
|
|
||||||
- 完整的性能统计
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
device: str = "cuda",
|
|
||||||
offload_device: str = "cpu",
|
|
||||||
max_gpu_tensors: int = 2,
|
|
||||||
non_blocking: bool = False,
|
|
||||||
):
|
|
||||||
self.device = torch.device(device)
|
|
||||||
self.offload_device = torch.device(offload_device)
|
|
||||||
self._gpu_pool: Dict[int, torch.Tensor] = {}
|
|
||||||
self._cpu_storage: Dict[int, torch.Tensor] = {}
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
self._tensor_id_counter = 0
|
|
||||||
self._max_gpu_tensors = max_gpu_tensors
|
|
||||||
self._access_order: List[int] = []
|
|
||||||
self.non_blocking = non_blocking
|
|
||||||
|
|
||||||
# 统计信息
|
|
||||||
self.stats = {
|
|
||||||
'load_count': 0,
|
|
||||||
'evict_count': 0,
|
|
||||||
'dispatch_count': 0,
|
|
||||||
'transfer_times_ms': [],
|
|
||||||
}
|
|
||||||
|
|
||||||
def _next_id(self) -> int:
|
|
||||||
tid = self._tensor_id_counter
|
|
||||||
self._tensor_id_counter += 1
|
|
||||||
return tid
|
|
||||||
|
|
||||||
def wrap(self, tensor: torch.Tensor) -> OffloadedTensor:
|
|
||||||
"""包装 tensor 为虚拟 GPU tensor"""
|
|
||||||
if isinstance(tensor, OffloadedTensor):
|
|
||||||
return tensor
|
|
||||||
|
|
||||||
tensor_id = self._next_id()
|
|
||||||
cpu_tensor = tensor.detach().to(self.offload_device)
|
|
||||||
self._cpu_storage[tensor_id] = cpu_tensor
|
|
||||||
|
|
||||||
return OffloadedTensor(cpu_tensor, self, tensor_id)
|
|
||||||
|
|
||||||
def get_gpu_tensor(self, real_tensor: torch.Tensor, tensor_id: int) -> torch.Tensor:
|
|
||||||
"""获取 GPU 上的数据(LRU 缓存)"""
|
|
||||||
with self._lock:
|
|
||||||
self.stats['load_count'] += 1
|
|
||||||
|
|
||||||
if tensor_id in self._gpu_pool:
|
|
||||||
# 已在 GPU 上,更新 LRU
|
|
||||||
if tensor_id in self._access_order:
|
|
||||||
self._access_order.remove(tensor_id)
|
|
||||||
self._access_order.append(tensor_id)
|
|
||||||
return self._gpu_pool[tensor_id]
|
|
||||||
|
|
||||||
# LRU 驱逐
|
|
||||||
while len(self._gpu_pool) >= self._max_gpu_tensors:
|
|
||||||
if self._access_order:
|
|
||||||
evict_id = self._access_order.pop(0)
|
|
||||||
if evict_id in self._gpu_pool:
|
|
||||||
del self._gpu_pool[evict_id]
|
|
||||||
self.stats['evict_count'] += 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
|
|
||||||
# 加载到 GPU
|
|
||||||
cpu_tensor = self._cpu_storage.get(tensor_id, real_tensor)
|
|
||||||
gpu_tensor = cpu_tensor.to(self.device, non_blocking=self.non_blocking)
|
|
||||||
self._gpu_pool[tensor_id] = gpu_tensor
|
|
||||||
self._access_order.append(tensor_id)
|
|
||||||
|
|
||||||
return gpu_tensor
|
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
|
|
||||||
"""获取统计信息"""
|
|
||||||
transfer_times = self.stats['transfer_times_ms']
|
|
||||||
return {
|
|
||||||
'load_count': self.stats['load_count'],
|
|
||||||
'evict_count': self.stats['evict_count'],
|
|
||||||
'dispatch_count': self.stats['dispatch_count'],
|
|
||||||
'gpu_pool_size': len(self._gpu_pool),
|
|
||||||
'total_tensors': len(self._cpu_storage),
|
|
||||||
'total_transfer_time_ms': sum(transfer_times),
|
|
||||||
'avg_transfer_time_ms': sum(transfer_times) / len(transfer_times) if transfer_times else 0,
|
|
||||||
'transfer_times_ms': list(transfer_times),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class OffloadModuleWrapper(nn.Module):
|
|
||||||
"""包装 nn.Module,实现参数级别的卸载"""
|
|
||||||
|
|
||||||
def __init__(self, module: nn.Module, manager: OffloadManager):
|
|
||||||
super().__init__()
|
|
||||||
self._original_module = module
|
|
||||||
self._manager = manager
|
|
||||||
self._wrap_parameters(module, "")
|
|
||||||
|
|
||||||
def _wrap_parameters(self, module: nn.Module, prefix: str):
|
|
||||||
"""递归包装模块的所有参数"""
|
|
||||||
for name, param in list(module.named_parameters(recurse=False)):
|
|
||||||
param.requires_grad_(False)
|
|
||||||
wrapped = self._manager.wrap(param.data)
|
|
||||||
delattr(module, name)
|
|
||||||
setattr(module, name, wrapped)
|
|
||||||
|
|
||||||
for child_name, child in list(module.named_children()):
|
|
||||||
self._wrap_parameters(child, prefix + child_name + ".")
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
|
||||||
return self._original_module(*args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Part 2: 高级模块
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
class ChunkedOffloadLinear(nn.Module):
|
|
||||||
"""
|
|
||||||
沿着 seqlen 维度分块的 Linear 层
|
|
||||||
|
|
||||||
将输入 [seqlen, in_features] 分成多个 chunks,每个 chunk 独立进行 GEMM 计算。
|
|
||||||
weight 使用 OffloadedTensor,按需加载到 GPU。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
in_features: 输入特征维度
|
|
||||||
out_features: 输出特征维度
|
|
||||||
chunk_size: 每个 chunk 的大小
|
|
||||||
max_gpu_tensors: GPU 上最多缓存的 tensor 数量
|
|
||||||
non_blocking: 是否使用异步传输
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
in_features: int,
|
|
||||||
out_features: int,
|
|
||||||
chunk_size: int = 4096,
|
|
||||||
max_gpu_tensors: int = 2,
|
|
||||||
non_blocking: bool = False,
|
|
||||||
bias: bool = False,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.in_features = in_features
|
|
||||||
self.out_features = out_features
|
|
||||||
self.chunk_size = chunk_size
|
|
||||||
|
|
||||||
self.manager = OffloadManager(
|
|
||||||
max_gpu_tensors=max_gpu_tensors,
|
|
||||||
non_blocking=non_blocking
|
|
||||||
)
|
|
||||||
|
|
||||||
weight_tensor = torch.empty(out_features, in_features, dtype=torch.float16)
|
|
||||||
nn.init.xavier_uniform_(weight_tensor)
|
|
||||||
weight_tensor.requires_grad_(False)
|
|
||||||
|
|
||||||
self.weight = self.manager.wrap(weight_tensor)
|
|
||||||
self.bias = None
|
|
||||||
if bias:
|
|
||||||
self.bias = nn.Parameter(torch.empty(out_features))
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
||||||
seqlen = x.shape[0]
|
|
||||||
|
|
||||||
if seqlen <= self.chunk_size:
|
|
||||||
return self._compute_chunk(x)
|
|
||||||
|
|
||||||
outputs = []
|
|
||||||
for start_idx in range(0, seqlen, self.chunk_size):
|
|
||||||
end_idx = min(start_idx + self.chunk_size, seqlen)
|
|
||||||
chunk = x[start_idx:end_idx]
|
|
||||||
chunk_output = self._compute_chunk(chunk)
|
|
||||||
outputs.append(chunk_output)
|
|
||||||
|
|
||||||
return torch.cat(outputs, dim=0)
|
|
||||||
|
|
||||||
def _compute_chunk(self, chunk: torch.Tensor) -> torch.Tensor:
|
|
||||||
return torch.nn.functional.linear(chunk, self.weight, self.bias)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# 辅助函数
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def calculate_memory(
|
|
||||||
seqlen: int,
|
|
||||||
in_features: int,
|
|
||||||
out_features: int,
|
|
||||||
dtype: torch.dtype = torch.float16,
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
"""计算显存占用(MB)"""
|
|
||||||
element_size = torch.finfo(dtype).bits / 8
|
|
||||||
|
|
||||||
activation = seqlen * in_features * element_size / (1024 ** 2)
|
|
||||||
weight = in_features * out_features * element_size / (1024 ** 2)
|
|
||||||
output = seqlen * out_features * element_size / (1024 ** 2)
|
|
||||||
|
|
||||||
total = activation + weight + output
|
|
||||||
peak = max(activation, output) + weight
|
|
||||||
|
|
||||||
return {
|
|
||||||
'activation_mb': activation,
|
|
||||||
'weight_mb': weight,
|
|
||||||
'output_mb': output,
|
|
||||||
'total_mb': total,
|
|
||||||
'peak_mb': peak,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def run_benchmark(
|
|
||||||
layer: nn.Module,
|
|
||||||
input_tensor: torch.Tensor,
|
|
||||||
num_runs: int = 3,
|
|
||||||
) -> Dict[str, float]:
|
|
||||||
"""运行性能测试"""
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Warmup
|
|
||||||
with torch.no_grad():
|
|
||||||
_ = layer(input_tensor)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
# Benchmark
|
|
||||||
start_time = time.time()
|
|
||||||
for _ in range(num_runs):
|
|
||||||
with torch.no_grad():
|
|
||||||
output = layer(input_tensor)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
elapsed = time.time() - start_time
|
|
||||||
avg_time = elapsed / num_runs
|
|
||||||
|
|
||||||
total_elements = input_tensor.numel() + output.numel()
|
|
||||||
throughput = total_elements / avg_time / 1e6
|
|
||||||
|
|
||||||
return {
|
|
||||||
'avg_time_ms': avg_time * 1000,
|
|
||||||
'throughput_meps': throughput,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Part 3: 测试套件 - 功能测试
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def test_1_basic_offloaded_tensor():
|
|
||||||
"""测试 OffloadedTensor 基本功能"""
|
|
||||||
print("\n=== Test 1: Basic OffloadedTensor ===")
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("CUDA not available, skipping")
|
|
||||||
return
|
|
||||||
|
|
||||||
manager = OffloadManager(max_gpu_tensors=2)
|
|
||||||
|
|
||||||
t1 = torch.randn(4, 4)
|
|
||||||
t2 = torch.randn(4, 4)
|
|
||||||
t3 = torch.randn(4, 4)
|
|
||||||
|
|
||||||
w1 = manager.wrap(t1)
|
|
||||||
w2 = manager.wrap(t2)
|
|
||||||
w3 = manager.wrap(t3)
|
|
||||||
|
|
||||||
print(f"✓ Created OffloadedTensors")
|
|
||||||
print(f" w1.device: {w1.device}")
|
|
||||||
print(f" w2.device: {w2.device}")
|
|
||||||
|
|
||||||
assert w1.device.type == "cuda"
|
|
||||||
print(f"✓ is_cuda check passed")
|
|
||||||
|
|
||||||
result = w1 + w2
|
|
||||||
print(f"✓ Addition works: {result.shape}")
|
|
||||||
|
|
||||||
stats = manager.get_stats()
|
|
||||||
print(f"✓ Manager stats: {stats}")
|
|
||||||
print("PASSED\n")
|
|
||||||
|
|
||||||
|
|
||||||
def test_2_mlp_with_offload():
|
|
||||||
"""测试 MLP 模型使用 OffloadedTensor"""
|
|
||||||
print("\n=== Test 2: MLP with OffloadedTensor ===")
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("CUDA not available, skipping")
|
|
||||||
return
|
|
||||||
|
|
||||||
class SimpleMLP(nn.Module):
|
|
||||||
def __init__(self, hidden_size=128, intermediate_size=256):
|
|
||||||
super().__init__()
|
|
||||||
self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
|
|
||||||
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
|
|
||||||
return self.down_proj(nn.functional.silu(gate) * up)
|
|
||||||
|
|
||||||
hidden_size = 128
|
|
||||||
intermediate_size = 256
|
|
||||||
batch_size, seq_len = 2, 4
|
|
||||||
|
|
||||||
input_ids = torch.randn(batch_size, seq_len, hidden_size, device="cuda")
|
|
||||||
|
|
||||||
model_original = SimpleMLP(hidden_size, intermediate_size)
|
|
||||||
model_original.to("cuda")
|
|
||||||
model_original.eval()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
expected = model_original(input_ids)
|
|
||||||
|
|
||||||
state_dict = model_original.state_dict()
|
|
||||||
|
|
||||||
model = SimpleMLP(hidden_size, intermediate_size)
|
|
||||||
model.load_state_dict(state_dict)
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
offloaded_model, manager = apply_offload_to_model(model, max_gpu_tensors=2)
|
|
||||||
offloaded_model.eval()
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
output = offloaded_model(input_ids)
|
|
||||||
|
|
||||||
print(f"✓ Forward pass completed: {output.shape}")
|
|
||||||
|
|
||||||
stats = manager.get_stats()
|
|
||||||
print(f"✓ Offload stats: {stats}")
|
|
||||||
|
|
||||||
diff = (output - expected).abs().max().item()
|
|
||||||
print(f"✓ Output correctness: max diff = {diff:.6f}")
|
|
||||||
|
|
||||||
assert diff < 1e-5
|
|
||||||
print("PASSED\n")
|
|
||||||
|
|
||||||
|
|
||||||
def apply_offload_to_model(model: nn.Module, max_gpu_tensors: int = 2):
|
|
||||||
"""应用卸载到模型的所有参数"""
|
|
||||||
manager = OffloadManager(max_gpu_tensors=max_gpu_tensors)
|
|
||||||
wrapper = OffloadModuleWrapper(model, manager)
|
|
||||||
return wrapper, manager
|
|
||||||
|
|
||||||
|
|
||||||
def test_3_lru_eviction():
|
|
||||||
"""测试 LRU 驱逐机制"""
|
|
||||||
print("\n=== Test 3: LRU Eviction ===")
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("CUDA not available, skipping")
|
|
||||||
return
|
|
||||||
|
|
||||||
manager = OffloadManager(max_gpu_tensors=2)
|
|
||||||
|
|
||||||
tensors = [torch.randn(2, 2) for _ in range(4)]
|
|
||||||
wrapped = [manager.wrap(t) for t in tensors]
|
|
||||||
|
|
||||||
print(f"✓ Created {len(wrapped)} OffloadedTensors")
|
|
||||||
print(f" GPU pool capacity: {manager._max_gpu_tensors}")
|
|
||||||
|
|
||||||
_ = wrapped[0] + wrapped[1]
|
|
||||||
stats = manager.get_stats()
|
|
||||||
print(f"✓ After accessing t1, t2: GPU pool = {stats['gpu_pool_size']}")
|
|
||||||
|
|
||||||
_ = wrapped[2] + wrapped[2]
|
|
||||||
stats = manager.get_stats()
|
|
||||||
print(f"✓ After accessing t3: GPU pool = {stats['gpu_pool_size']}, evicted = {stats['evict_count']}")
|
|
||||||
|
|
||||||
_ = wrapped[3] + wrapped[3]
|
|
||||||
stats = manager.get_stats()
|
|
||||||
print(f"✓ After accessing t4: GPU pool = {stats['gpu_pool_size']}, evicted = {stats['evict_count']}")
|
|
||||||
|
|
||||||
assert stats['evict_count'] >= 1
|
|
||||||
print("PASSED\n")
|
|
||||||
|
|
||||||
|
|
||||||
def test_4_correctness():
|
|
||||||
"""测试输出正确性"""
|
|
||||||
print("\n=== Test 4: Correctness Check ===")
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("CUDA not available, skipping")
|
|
||||||
return
|
|
||||||
|
|
||||||
in_features = 512
|
|
||||||
out_features = 1024
|
|
||||||
seqlen = 4096
|
|
||||||
chunk_size = 1024
|
|
||||||
|
|
||||||
x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)
|
|
||||||
|
|
||||||
# 创建标准层并保存权重
|
|
||||||
linear = nn.Linear(in_features, out_features, bias=False)
|
|
||||||
linear.to("cuda", dtype=torch.float16)
|
|
||||||
linear.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
expected = linear(x)
|
|
||||||
|
|
||||||
print(f"✓ Got expected output")
|
|
||||||
|
|
||||||
# 创建 ChunkedOffloadLinear,使用相同的权重
|
|
||||||
chunked_layer = ChunkedOffloadLinear(in_features, out_features, chunk_size, max_gpu_tensors=2)
|
|
||||||
|
|
||||||
# 复制权重到 chunked_layer
|
|
||||||
with torch.no_grad():
|
|
||||||
weight_data = linear.weight.data.cpu()
|
|
||||||
chunked_layer.manager._cpu_storage[0] = weight_data
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
actual = chunked_layer(x)
|
|
||||||
|
|
||||||
print(f"✓ Got actual output")
|
|
||||||
|
|
||||||
diff = (actual - expected).abs().max().item()
|
|
||||||
print(f"✓ Max difference: {diff:.6f}")
|
|
||||||
|
|
||||||
assert diff < 1e-5
|
|
||||||
print("PASSED\n")
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Part 3: 测试套件 - 性能测试
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def test_5_memory_analysis():
|
|
||||||
"""分析内存占用"""
|
|
||||||
print("\n=== Test 5: Memory Analysis ===")
|
|
||||||
|
|
||||||
in_features = 4096
|
|
||||||
out_features = 12244
|
|
||||||
chunk_size = 4096
|
|
||||||
|
|
||||||
seqlens = [4096, 16384, 65536, 131072]
|
|
||||||
|
|
||||||
print(f"\nMemory Analysis (in={in_features}, out={out_features}, chunk={chunk_size}):")
|
|
||||||
print(f"{'Seqlen':>10} | {'Activation':>12} | {'Weight':>12} | {'Output':>12} | {'Peak':>12} | {'Chunked':>12}")
|
|
||||||
print("-" * 90)
|
|
||||||
|
|
||||||
for seqlen in seqlens:
|
|
||||||
full = calculate_memory(seqlen, in_features, out_features)
|
|
||||||
chunked = calculate_memory(chunk_size, in_features, out_features)
|
|
||||||
|
|
||||||
print(f"{seqlen:>10} | "
|
|
||||||
f"{full['activation_mb']:>10.1f}MB | "
|
|
||||||
f"{full['weight_mb']:>10.1f}MB | "
|
|
||||||
f"{full['output_mb']:>10.1f}MB | "
|
|
||||||
f"{full['peak_mb']:>10.1f}MB | "
|
|
||||||
f"{chunked['peak_mb']:>10.1f}MB")
|
|
||||||
|
|
||||||
print("\n✓ Chunked offload 显存占用恒定,与序列长度无关!")
|
|
||||||
print("PASSED\n")
|
|
||||||
|
|
||||||
|
|
||||||
def test_6_long_sequence():
|
|
||||||
"""测试超长序列"""
|
|
||||||
print("\n=== Test 6: Long Sequence (128K tokens) ===")
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("CUDA not available, skipping")
|
|
||||||
return
|
|
||||||
|
|
||||||
in_features = 4096
|
|
||||||
out_features = 12244
|
|
||||||
seqlen = 128 * 1024
|
|
||||||
chunk_size = 4096
|
|
||||||
|
|
||||||
full = calculate_memory(seqlen, in_features, out_features)
|
|
||||||
chunked = calculate_memory(chunk_size, in_features, out_features)
|
|
||||||
|
|
||||||
print(f"Memory Comparison:")
|
|
||||||
print(f" Full: {full['peak_mb']:.1f} MB")
|
|
||||||
print(f" Chunked: {chunked['peak_mb']:.1f} MB")
|
|
||||||
print(f" Savings: {(1 - chunked['peak_mb']/full['peak_mb'])*100:.1f}%")
|
|
||||||
|
|
||||||
layer = ChunkedOffloadLinear(in_features, out_features, chunk_size, max_gpu_tensors=1)
|
|
||||||
x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
start = time.time()
|
|
||||||
output = layer(x)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
elapsed = (time.time() - start) * 1000
|
|
||||||
|
|
||||||
print(f"✓ Forward pass: {output.shape}")
|
|
||||||
print(f" Time: {elapsed:.1f} ms")
|
|
||||||
print(f" Throughput: {seqlen/elapsed/1e3:.1f}K tokens/sec")
|
|
||||||
|
|
||||||
stats = layer.manager.get_stats()
|
|
||||||
print(f"✓ Chunks processed: {seqlen // chunk_size}")
|
|
||||||
print(f"✓ Load count: {stats['load_count']}")
|
|
||||||
print("PASSED\n")
|
|
||||||
|
|
||||||
|
|
||||||
def test_7_performance_comparison():
|
|
||||||
"""性能对比测试"""
|
|
||||||
print("\n=== Test 7: Performance Comparison ===")
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("CUDA not available, skipping")
|
|
||||||
return
|
|
||||||
|
|
||||||
in_features = 4096
|
|
||||||
out_features = 12244
|
|
||||||
seqlen = 16384
|
|
||||||
chunk_size = 4096
|
|
||||||
|
|
||||||
x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)
|
|
||||||
|
|
||||||
linear = nn.Linear(in_features, out_features, bias=False).cuda().half().eval()
|
|
||||||
standard_stats = run_benchmark(linear, x, num_runs=5)
|
|
||||||
print(f"✓ Standard Linear: {standard_stats['avg_time_ms']:.1f} ms")
|
|
||||||
|
|
||||||
chunked_layer = ChunkedOffloadLinear(in_features, out_features, chunk_size, max_gpu_tensors=1)
|
|
||||||
chunked_stats = run_benchmark(chunked_layer, x, num_runs=5)
|
|
||||||
print(f"✓ ChunkedOffloadLinear: {chunked_stats['avg_time_ms']:.1f} ms")
|
|
||||||
|
|
||||||
speedup = standard_stats['avg_time_ms'] / chunked_stats['avg_time_ms']
|
|
||||||
print(f"✓ Speedup: {speedup:.2f}x")
|
|
||||||
print("PASSED\n")
|
|
||||||
|
|
||||||
|
|
||||||
def test_8_transformers_layer():
|
|
||||||
"""测试实际 transformers 权重"""
|
|
||||||
print("\n=== Test 8: Transformers Layer Test ===")
|
|
||||||
|
|
||||||
try:
|
|
||||||
from transformers import AutoModelForCausalLM
|
|
||||||
except ImportError:
|
|
||||||
print("transformers not installed, skipping")
|
|
||||||
return
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("CUDA not available, skipping")
|
|
||||||
return
|
|
||||||
|
|
||||||
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
|
|
||||||
|
|
||||||
try:
|
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
model_name,
|
|
||||||
torch_dtype=torch.float16,
|
|
||||||
trust_remote_code=True,
|
|
||||||
)
|
|
||||||
model.eval()
|
|
||||||
model.to("cuda")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Failed to load model: {e}")
|
|
||||||
return
|
|
||||||
|
|
||||||
down_proj = model.model.layers[0].mlp.down_proj
|
|
||||||
print(f"✓ Got layer: {down_proj.in_features} -> {down_proj.out_features}")
|
|
||||||
|
|
||||||
batch_size, seq_len = 1, 4
|
|
||||||
test_input = torch.randn(batch_size, seq_len, down_proj.in_features, device="cuda", dtype=torch.float16)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
normal_output = down_proj(test_input)
|
|
||||||
|
|
||||||
print(f"✓ Normal inference: {normal_output.shape}")
|
|
||||||
|
|
||||||
import copy
|
|
||||||
test_linear = nn.Linear(down_proj.in_features, down_proj.out_features, bias=False)
|
|
||||||
test_linear.load_state_dict(copy.deepcopy(down_proj.state_dict()))
|
|
||||||
test_linear.to("cuda", dtype=torch.float16)
|
|
||||||
test_linear.eval()
|
|
||||||
|
|
||||||
manager = OffloadManager(max_gpu_tensors=2)
|
|
||||||
offloaded_layer = OffloadModuleWrapper(test_linear, manager)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
offload_output = offloaded_layer(test_input)
|
|
||||||
|
|
||||||
print(f"✓ Offload inference: {offload_output.shape}")
|
|
||||||
|
|
||||||
stats = manager.get_stats()
|
|
||||||
print(f"✓ Stats: {stats}")
|
|
||||||
|
|
||||||
diff = (offload_output - normal_output).abs().max().item()
|
|
||||||
print(f"✓ Max diff: {diff:.6f}")
|
|
||||||
|
|
||||||
assert diff < 1e-5
|
|
||||||
print("PASSED\n")
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Part 3: 测试套件 - 同步分析
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def test_9_sync_behavior_analysis():
|
|
||||||
"""分析同步传输 vs 异步传输"""
|
|
||||||
print("\n=== Test 9: Sync Behavior Analysis ===")
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("CUDA not available, skipping")
|
|
||||||
return
|
|
||||||
|
|
||||||
in_features = 4096
|
|
||||||
out_features = 12244
|
|
||||||
seqlen = 16384
|
|
||||||
chunk_size = 4096
|
|
||||||
|
|
||||||
print(f"Config: in={in_features}, out={out_features}, seqlen={seqlen}, chunk={chunk_size}")
|
|
||||||
print(f"Num chunks: {seqlen // chunk_size}")
|
|
||||||
|
|
||||||
x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)
|
|
||||||
|
|
||||||
# 同步版本
|
|
||||||
print(f"\n--- 同步传输 (non_blocking=False) ---")
|
|
||||||
layer_sync = ChunkedOffloadLinear(in_features, out_features, chunk_size, non_blocking=False)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
start = time.time()
|
|
||||||
_ = layer_sync(x)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
sync_time_ms = (time.time() - start) * 1000
|
|
||||||
|
|
||||||
stats_sync = layer_sync.manager.get_stats()
|
|
||||||
print(f"总时间: {sync_time_ms:.2f} ms")
|
|
||||||
print(f"传输时间: {stats_sync['total_transfer_time_ms']:.2f} ms")
|
|
||||||
print(f"计算时间: {sync_time_ms - stats_sync['total_transfer_time_ms']:.2f} ms")
|
|
||||||
print(f"加载次数: {stats_sync['load_count']}")
|
|
||||||
|
|
||||||
# 异步版本
|
|
||||||
print(f"\n--- 异步传输 (non_blocking=True) ---")
|
|
||||||
layer_async = ChunkedOffloadLinear(in_features, out_features, chunk_size, non_blocking=True)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
start = time.time()
|
|
||||||
_ = layer_async(x)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
async_time_ms = (time.time() - start) * 1000
|
|
||||||
|
|
||||||
stats_async = layer_async.manager.get_stats()
|
|
||||||
print(f"总时间: {async_time_ms:.2f} ms")
|
|
||||||
print(f"传输时间: {stats_async['total_transfer_time_ms']:.2f} ms")
|
|
||||||
print(f"计算时间: {async_time_ms - stats_async['total_transfer_time_ms']:.2f} ms")
|
|
||||||
print(f"加载次数: {stats_async['load_count']}")
|
|
||||||
|
|
||||||
# 对比
|
|
||||||
print(f"\n--- 对比 ---")
|
|
||||||
print(f"总加速比: {sync_time_ms / async_time_ms:.2f}x")
|
|
||||||
|
|
||||||
if stats_async['total_transfer_time_ms'] > 0:
|
|
||||||
print(f"传输加速比: {stats_sync['total_transfer_time_ms'] / stats_async['total_transfer_time_ms']:.2f}x")
|
|
||||||
|
|
||||||
print("\n关键发现:")
|
|
||||||
print(f" 1. 同步传输阻塞 CPU 线程")
|
|
||||||
print(f" 2. 异步传输可提高吞吐量")
|
|
||||||
print(f" 3. 首次运行包含 JIT 编译开销")
|
|
||||||
print("PASSED\n")
|
|
||||||
|
|
||||||
|
|
||||||
def test_10_profiler_analysis():
|
|
||||||
"""使用 Profiler 分析内核执行"""
|
|
||||||
print("\n=== Test 10: Profiler Analysis ===")
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
|
||||||
print("CUDA not available, skipping")
|
|
||||||
return
|
|
||||||
|
|
||||||
in_features = 4096
|
|
||||||
out_features = 12244
|
|
||||||
seqlen = 16384
|
|
||||||
chunk_size = 4096
|
|
||||||
|
|
||||||
layer = ChunkedOffloadLinear(in_features, out_features, chunk_size)
|
|
||||||
x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)
|
|
||||||
|
|
||||||
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p:
|
|
||||||
with torch.no_grad():
|
|
||||||
_ = layer(x)
|
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
kernel_counts = {}
|
|
||||||
for event in p.key_averages():
|
|
||||||
if event.device_type == torch.profiler.DeviceType.CUDA:
|
|
||||||
name = event.key
|
|
||||||
kernel_counts[name] = kernel_counts.get(name, 0) + 1
|
|
||||||
|
|
||||||
print(f"内核调用统计:")
|
|
||||||
print(f"{'内核类型':<50} {'调用次数':<10}")
|
|
||||||
print("-" * 60)
|
|
||||||
|
|
||||||
for name, count in sorted(kernel_counts.items(), key=lambda x: -x[1])[:15]:
|
|
||||||
name_short = name[:48]
|
|
||||||
print(f"{name_short:<50} {count:<10}")
|
|
||||||
|
|
||||||
memcpy_count = sum(count for name, count in kernel_counts.items() if 'memcpy' in name.lower())
|
|
||||||
print(f"\n分析:")
|
|
||||||
print(f" - 总共 {len(kernel_counts)} 种不同的 CUDA 内核")
|
|
||||||
print(f" - 总调用次数: {sum(kernel_counts.values())}")
|
|
||||||
print(f" - 内存拷贝: {memcpy_count} 次")
|
|
||||||
print("PASSED\n")
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# 主测试入口
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""运行所有测试"""
|
|
||||||
print("=" * 70)
|
|
||||||
print("OffloadedTensor 统一测试套件")
|
|
||||||
print("=" * 70)
|
|
||||||
|
|
||||||
# 功能测试
|
|
||||||
print("\n" + "=" * 70)
|
|
||||||
print("功能测试 (Tests 1-4)")
|
|
||||||
print("=" * 70)
|
|
||||||
test_1_basic_offloaded_tensor()
|
|
||||||
test_2_mlp_with_offload()
|
|
||||||
test_3_lru_eviction()
|
|
||||||
test_4_correctness()
|
|
||||||
|
|
||||||
# 性能测试
|
|
||||||
print("\n" + "=" * 70)
|
|
||||||
print("性能测试 (Tests 5-8)")
|
|
||||||
print("=" * 70)
|
|
||||||
test_5_memory_analysis()
|
|
||||||
test_6_long_sequence()
|
|
||||||
test_7_performance_comparison()
|
|
||||||
test_8_transformers_layer()
|
|
||||||
|
|
||||||
# 同步分析
|
|
||||||
print("\n" + "=" * 70)
|
|
||||||
print("同步分析 (Tests 9-10)")
|
|
||||||
print("=" * 70)
|
|
||||||
test_9_sync_behavior_analysis()
|
|
||||||
test_10_profiler_analysis()
|
|
||||||
|
|
||||||
print("=" * 70)
|
|
||||||
print("所有测试完成!")
|
|
||||||
print("=" * 70)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,198 +0,0 @@
|
|||||||
"""Test for torch distributed port conflict fix.
|
|
||||||
|
|
||||||
This test verifies that:
|
|
||||||
1. Multiple independent processes can run simultaneously (dynamic port allocation)
|
|
||||||
2. Sequential LLM creation in same process works (proper cleanup)
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Test parallel processes (requires 2 GPUs)
|
|
||||||
python tests/test_port_conflict.py --model ~/models/Qwen3-4B --gpus 4,5 --test parallel
|
|
||||||
|
|
||||||
# Test sequential creation in same process
|
|
||||||
CUDA_VISIBLE_DEVICES=4 python tests/test_port_conflict.py --model ~/models/Qwen3-4B --test sequential
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
def test_sequential_creation(model_path: str, enable_offload: bool = True):
|
|
||||||
"""Test creating multiple LLM instances sequentially in same process."""
|
|
||||||
# Add project root to path
|
|
||||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
sys.path.insert(0, project_root)
|
|
||||||
|
|
||||||
from nanovllm import LLM, SamplingParams
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print("Test: Sequential LLM Creation (same process)")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
for i in range(3):
|
|
||||||
print(f"\n--- Creating LLM instance {i+1}/3 ---")
|
|
||||||
|
|
||||||
llm_kwargs = {"enable_cpu_offload": enable_offload}
|
|
||||||
if enable_offload:
|
|
||||||
llm_kwargs["num_gpu_blocks"] = 2
|
|
||||||
|
|
||||||
llm = LLM(model_path, **llm_kwargs)
|
|
||||||
|
|
||||||
# Simple generation
|
|
||||||
outputs = llm.generate(
|
|
||||||
["Hello, how are you?"],
|
|
||||||
SamplingParams(max_tokens=20)
|
|
||||||
)
|
|
||||||
print(f"Output: {outputs[0]['text'][:50]}...")
|
|
||||||
|
|
||||||
# Explicit cleanup
|
|
||||||
llm.close()
|
|
||||||
print(f"Instance {i+1} closed successfully")
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("PASSED: test_sequential_creation")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
def test_context_manager(model_path: str, enable_offload: bool = True):
|
|
||||||
"""Test LLM with context manager."""
|
|
||||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
sys.path.insert(0, project_root)
|
|
||||||
|
|
||||||
from nanovllm import LLM, SamplingParams
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print("Test: Context Manager")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
for i in range(2):
|
|
||||||
print(f"\n--- Context manager instance {i+1}/2 ---")
|
|
||||||
|
|
||||||
llm_kwargs = {"enable_cpu_offload": enable_offload}
|
|
||||||
if enable_offload:
|
|
||||||
llm_kwargs["num_gpu_blocks"] = 2
|
|
||||||
|
|
||||||
with LLM(model_path, **llm_kwargs) as llm:
|
|
||||||
outputs = llm.generate(
|
|
||||||
["What is 2+2?"],
|
|
||||||
SamplingParams(max_tokens=20)
|
|
||||||
)
|
|
||||||
print(f"Output: {outputs[0]['text'][:50]}...")
|
|
||||||
|
|
||||||
print(f"Instance {i+1} auto-closed via context manager")
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("PASSED: test_context_manager")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
|
|
||||||
def test_parallel_processes(model_path: str, gpus: str, enable_offload: bool = True):
|
|
||||||
"""Test running multiple nanovllm processes in parallel."""
|
|
||||||
gpu_list = [int(g.strip()) for g in gpus.split(",")]
|
|
||||||
if len(gpu_list) < 2:
|
|
||||||
print("ERROR: Need at least 2 GPUs for parallel test")
|
|
||||||
return False
|
|
||||||
|
|
||||||
print("=" * 60)
|
|
||||||
print(f"Test: Parallel Processes (GPUs: {gpu_list})")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
|
|
||||||
# Script to run in each subprocess
|
|
||||||
script = f'''
|
|
||||||
import sys
|
|
||||||
sys.path.insert(0, "{project_root}")
|
|
||||||
import os
|
|
||||||
from nanovllm import LLM, SamplingParams
|
|
||||||
|
|
||||||
gpu = os.environ.get("CUDA_VISIBLE_DEVICES", "?")
|
|
||||||
print(f"[GPU {{gpu}}] Starting LLM...")
|
|
||||||
|
|
||||||
llm_kwargs = {{"enable_cpu_offload": {enable_offload}}}
|
|
||||||
if {enable_offload}:
|
|
||||||
llm_kwargs["num_gpu_blocks"] = 2
|
|
||||||
|
|
||||||
llm = LLM("{model_path}", **llm_kwargs)
|
|
||||||
print(f"[GPU {{gpu}}] LLM initialized, generating...")
|
|
||||||
|
|
||||||
outputs = llm.generate(["Hello world"], SamplingParams(max_tokens=10))
|
|
||||||
print(f"[GPU {{gpu}}] Output: {{outputs[0]['text'][:30]}}...")
|
|
||||||
|
|
||||||
llm.close()
|
|
||||||
print(f"[GPU {{gpu}}] Done")
|
|
||||||
'''
|
|
||||||
|
|
||||||
# Start processes on different GPUs
|
|
||||||
procs = []
|
|
||||||
for i, gpu in enumerate(gpu_list[:2]): # Use first 2 GPUs
|
|
||||||
print(f"\nStarting process on GPU {gpu}...")
|
|
||||||
env = os.environ.copy()
|
|
||||||
env["CUDA_VISIBLE_DEVICES"] = str(gpu)
|
|
||||||
|
|
||||||
p = subprocess.Popen(
|
|
||||||
[sys.executable, "-c", script],
|
|
||||||
env=env,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.STDOUT,
|
|
||||||
text=True
|
|
||||||
)
|
|
||||||
procs.append((gpu, p))
|
|
||||||
time.sleep(2) # Stagger starts to see concurrent running
|
|
||||||
|
|
||||||
# Wait and collect results
|
|
||||||
all_passed = True
|
|
||||||
for gpu, p in procs:
|
|
||||||
stdout, _ = p.communicate(timeout=300)
|
|
||||||
print(f"\n--- GPU {gpu} output ---")
|
|
||||||
print(stdout)
|
|
||||||
|
|
||||||
if p.returncode != 0:
|
|
||||||
print(f"ERROR: GPU {gpu} process failed with code {p.returncode}")
|
|
||||||
all_passed = False
|
|
||||||
else:
|
|
||||||
print(f"GPU {gpu} process completed successfully")
|
|
||||||
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
if all_passed:
|
|
||||||
print("PASSED: test_parallel_processes")
|
|
||||||
else:
|
|
||||||
print("FAILED: test_parallel_processes")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
return all_passed
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Test port conflict fix")
|
|
||||||
parser.add_argument("--model", "-m", required=True, help="Path to model")
|
|
||||||
parser.add_argument("--gpus", default="0,1", help="GPUs to use for parallel test (comma-separated)")
|
|
||||||
parser.add_argument("--test", choices=["sequential", "context", "parallel", "all"],
|
|
||||||
default="all", help="Which test to run")
|
|
||||||
parser.add_argument("--no-offload", action="store_true", help="Disable CPU offload")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
enable_offload = not args.no_offload
|
|
||||||
model_path = os.path.expanduser(args.model)
|
|
||||||
|
|
||||||
print(f"Model: {model_path}")
|
|
||||||
print(f"CPU Offload: {enable_offload}")
|
|
||||||
print(f"GPUs for parallel test: {args.gpus}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
if args.test in ["sequential", "all"]:
|
|
||||||
test_sequential_creation(model_path, enable_offload)
|
|
||||||
print()
|
|
||||||
|
|
||||||
if args.test in ["context", "all"]:
|
|
||||||
test_context_manager(model_path, enable_offload)
|
|
||||||
print()
|
|
||||||
|
|
||||||
if args.test in ["parallel", "all"]:
|
|
||||||
test_parallel_processes(model_path, args.gpus, enable_offload)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,409 +0,0 @@
|
|||||||
"""
|
|
||||||
RULER benchmark comprehensive test for LLM.
|
|
||||||
|
|
||||||
Tests multiple RULER tasks:
|
|
||||||
- NIAH (Needle-In-A-Haystack): single, multikey, multiquery, multivalue
|
|
||||||
- QA (Question Answering): qa_1, qa_2
|
|
||||||
- CWE (Common Word Extraction)
|
|
||||||
- FWE (Frequent Word Extraction)
|
|
||||||
- VT (Variable Tracking)
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Test all datasets with 2 samples each (debug mode)
|
|
||||||
python tests/test_ruler.py --enable-offload --num-samples 2
|
|
||||||
|
|
||||||
# Test specific datasets
|
|
||||||
python tests/test_ruler.py --enable-offload --datasets niah_single_1,qa_1
|
|
||||||
|
|
||||||
# Test all samples in all datasets
|
|
||||||
python tests/test_ruler.py --enable-offload
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import re
|
|
||||||
import gc
|
|
||||||
import time
|
|
||||||
import torch
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Dict, Tuple, Optional
|
|
||||||
|
|
||||||
from nanovllm import LLM, SamplingParams
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Constants
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
DEFAULT_DATA_DIR = Path(__file__).parent / "data/ruler_64k"
|
|
||||||
DEFAULT_MODEL = os.path.expanduser("~/models/Llama-3.1-8B-Instruct")
|
|
||||||
# Note: max_model_len must be > max_input_len to leave room for output tokens
|
|
||||||
# 64k benchmark has inputs up to 65536 tokens, so we need 65536 + 128 = 65664
|
|
||||||
DEFAULT_MAX_MODEL_LEN = 65664
|
|
||||||
DEFAULT_MAX_NEW_TOKENS = 128 # Larger for multi-value tasks
|
|
||||||
|
|
||||||
# Task categories for evaluation
|
|
||||||
NIAH_TASKS = ["niah_single_1", "niah_single_2", "niah_single_3",
|
|
||||||
"niah_multikey_1", "niah_multikey_2", "niah_multikey_3",
|
|
||||||
"niah_multiquery", "niah_multivalue"]
|
|
||||||
QA_TASKS = ["qa_1", "qa_2"]
|
|
||||||
RECALL_TASKS = ["cwe", "fwe", "vt"]
|
|
||||||
|
|
||||||
ALL_TASKS = NIAH_TASKS + QA_TASKS + RECALL_TASKS
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Data Loading
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def load_samples(filepath: Path, indices: Optional[List[int]] = None) -> List[dict]:
|
|
||||||
"""Load samples from a JSONL file."""
|
|
||||||
if not filepath.exists():
|
|
||||||
raise FileNotFoundError(f"Data file not found: {filepath}")
|
|
||||||
|
|
||||||
samples = []
|
|
||||||
with open(filepath) as f:
|
|
||||||
for i, line in enumerate(f):
|
|
||||||
if indices is None or i in indices:
|
|
||||||
sample = json.loads(line)
|
|
||||||
sample["_local_idx"] = i
|
|
||||||
samples.append(sample)
|
|
||||||
return samples
|
|
||||||
|
|
||||||
|
|
||||||
def count_samples(filepath: Path) -> int:
|
|
||||||
"""Count total samples in JSONL file."""
|
|
||||||
with open(filepath) as f:
|
|
||||||
return sum(1 for _ in f)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Evaluation Functions (Following RULER Official Metrics)
|
|
||||||
# Ref: https://github.com/NVIDIA/RULER/blob/main/scripts/eval/synthetic/constants.py
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def string_match_all(output_text: str, expected_list: List[str]) -> float:
|
|
||||||
"""
|
|
||||||
RULER official metric for NIAH, VT, CWE, FWE tasks.
|
|
||||||
|
|
||||||
Formula: sum([1.0 if r.lower() in pred.lower() else 0.0 for r in ref]) / len(ref)
|
|
||||||
|
|
||||||
Returns recall score (0.0 to 1.0): fraction of expected values found in output.
|
|
||||||
"""
|
|
||||||
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
|
|
||||||
output_lower = output_clean.lower()
|
|
||||||
|
|
||||||
if not expected_list:
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
found = sum(1.0 if exp.strip().lower() in output_lower else 0.0 for exp in expected_list)
|
|
||||||
return found / len(expected_list)
|
|
||||||
|
|
||||||
|
|
||||||
def string_match_part(output_text: str, expected_list: List[str]) -> float:
|
|
||||||
"""
|
|
||||||
RULER official metric for QA tasks.
|
|
||||||
|
|
||||||
Formula: max([1.0 if r.lower() in pred.lower() else 0.0 for r in ref])
|
|
||||||
|
|
||||||
Returns 1.0 if ANY expected value is found, 0.0 otherwise.
|
|
||||||
"""
|
|
||||||
output_clean = output_text.replace('<|im_end|>', '').replace('\r', ' ').replace('\n', ' ')
|
|
||||||
output_lower = output_clean.lower()
|
|
||||||
|
|
||||||
if not expected_list:
|
|
||||||
return 1.0
|
|
||||||
|
|
||||||
return max(1.0 if exp.strip().lower() in output_lower else 0.0 for exp in expected_list)
|
|
||||||
|
|
||||||
|
|
||||||
def evaluate_output(output_text: str, expected_outputs: List[str], task_name: str) -> Tuple[bool, float]:
|
|
||||||
"""
|
|
||||||
Evaluate model output using RULER official metrics.
|
|
||||||
|
|
||||||
- QA tasks: string_match_part (any match = full score)
|
|
||||||
- All other tasks: string_match_all (recall-based score)
|
|
||||||
|
|
||||||
Returns (passed, score) where passed = score >= 0.5
|
|
||||||
"""
|
|
||||||
if task_name in QA_TASKS:
|
|
||||||
score = string_match_part(output_text, expected_outputs)
|
|
||||||
else:
|
|
||||||
# NIAH, VT, CWE, FWE all use string_match_all
|
|
||||||
score = string_match_all(output_text, expected_outputs)
|
|
||||||
|
|
||||||
passed = score >= 0.5 # Consider pass if score >= 50%
|
|
||||||
return passed, score
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Test Runner
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def run_task_test(
|
|
||||||
llm: LLM,
|
|
||||||
task_name: str,
|
|
||||||
data_dir: Path,
|
|
||||||
sample_indices: Optional[List[int]] = None,
|
|
||||||
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
|
|
||||||
verbose: bool = True,
|
|
||||||
) -> Dict:
|
|
||||||
"""
|
|
||||||
Run test for a single RULER task.
|
|
||||||
|
|
||||||
Returns dict with: task, correct, total, score, results
|
|
||||||
"""
|
|
||||||
data_file = data_dir / task_name / "validation.jsonl"
|
|
||||||
samples = load_samples(data_file, sample_indices)
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f"\n Testing {task_name}: {len(samples)} samples")
|
|
||||||
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
temperature=0.1,
|
|
||||||
max_tokens=max_new_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
correct = 0
|
|
||||||
total_score = 0.0
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for sample in samples:
|
|
||||||
idx = sample.get("index", sample["_local_idx"])
|
|
||||||
prompt = sample["input"]
|
|
||||||
expected = sample["outputs"]
|
|
||||||
|
|
||||||
# Generate
|
|
||||||
outputs = llm.generate([prompt], sampling_params, use_tqdm=False)
|
|
||||||
output_text = outputs[0]["text"]
|
|
||||||
|
|
||||||
# Evaluate
|
|
||||||
passed, score = evaluate_output(output_text, expected, task_name)
|
|
||||||
if passed:
|
|
||||||
correct += 1
|
|
||||||
total_score += score
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
"index": idx,
|
|
||||||
"expected": expected,
|
|
||||||
"output": output_text[:200],
|
|
||||||
"passed": passed,
|
|
||||||
"score": score,
|
|
||||||
})
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
status = "PASS" if passed else "FAIL"
|
|
||||||
exp_preview = str(expected[0])[:30] if expected else "N/A"
|
|
||||||
out_preview = output_text[:50].replace('\n', ' ')
|
|
||||||
print(f" [{idx}] {status} (score={score:.2f}) exp={exp_preview}... out={out_preview}...")
|
|
||||||
|
|
||||||
avg_score = total_score / len(samples) if samples else 0.0
|
|
||||||
|
|
||||||
return {
|
|
||||||
"task": task_name,
|
|
||||||
"correct": correct,
|
|
||||||
"total": len(samples),
|
|
||||||
"accuracy": correct / len(samples) if samples else 0.0,
|
|
||||||
"avg_score": avg_score,
|
|
||||||
"results": results,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def run_ruler_benchmark(
|
|
||||||
model_path: str,
|
|
||||||
data_dir: Path,
|
|
||||||
datasets: Optional[List[str]] = None,
|
|
||||||
num_samples: Optional[int] = None,
|
|
||||||
max_model_len: int = DEFAULT_MAX_MODEL_LEN,
|
|
||||||
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
|
|
||||||
enable_cpu_offload: bool = False,
|
|
||||||
num_gpu_blocks: int = 4,
|
|
||||||
block_size: int = 1024,
|
|
||||||
num_kv_buffers: int = 4,
|
|
||||||
gpu_utilization: float = 0.9,
|
|
||||||
enforce_eager: bool = True,
|
|
||||||
verbose: bool = True,
|
|
||||||
sparse_policy: Optional[str] = None,
|
|
||||||
) -> Dict:
|
|
||||||
"""
|
|
||||||
Run RULER benchmark on multiple tasks.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path: Path to the model
|
|
||||||
data_dir: Directory containing task subdirectories
|
|
||||||
datasets: List of task names to test (None = all)
|
|
||||||
num_samples: Number of samples per task (None = all)
|
|
||||||
...other LLM config params...
|
|
||||||
sparse_policy: Sparse attention policy (FULL, QUEST, MINFERENCE, XATTN)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with overall results and per-task results
|
|
||||||
"""
|
|
||||||
# Determine tasks to run
|
|
||||||
if datasets is None:
|
|
||||||
tasks = [t for t in ALL_TASKS if (data_dir / t / "validation.jsonl").exists()]
|
|
||||||
else:
|
|
||||||
tasks = datasets
|
|
||||||
|
|
||||||
# Sample indices
|
|
||||||
sample_indices = list(range(num_samples)) if num_samples else None
|
|
||||||
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"RULER Benchmark")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f"Model: {model_path}")
|
|
||||||
print(f"Data dir: {data_dir}")
|
|
||||||
print(f"Tasks: {len(tasks)}")
|
|
||||||
print(f"Samples per task: {num_samples if num_samples else 'all'}")
|
|
||||||
print(f"CPU offload: {enable_cpu_offload}")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
|
|
||||||
# Initialize LLM
|
|
||||||
print("\nInitializing LLM...")
|
|
||||||
llm_kwargs = {
|
|
||||||
"max_model_len": max_model_len,
|
|
||||||
"max_num_batched_tokens": max_model_len,
|
|
||||||
"enforce_eager": enforce_eager,
|
|
||||||
"gpu_memory_utilization": gpu_utilization,
|
|
||||||
"kvcache_block_size": block_size,
|
|
||||||
"enable_cpu_offload": enable_cpu_offload,
|
|
||||||
}
|
|
||||||
if enable_cpu_offload:
|
|
||||||
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
|
||||||
llm_kwargs["num_kv_buffers"] = num_kv_buffers
|
|
||||||
if sparse_policy:
|
|
||||||
from nanovllm.config import SparsePolicyType
|
|
||||||
sparse_policy_type = SparsePolicyType[sparse_policy]
|
|
||||||
llm_kwargs["sparse_policy"] = sparse_policy_type
|
|
||||||
|
|
||||||
llm = LLM(model_path, **llm_kwargs)
|
|
||||||
|
|
||||||
# Run tests
|
|
||||||
start_time = time.time()
|
|
||||||
task_results = []
|
|
||||||
|
|
||||||
for task_name in tasks:
|
|
||||||
result = run_task_test(
|
|
||||||
llm=llm,
|
|
||||||
task_name=task_name,
|
|
||||||
data_dir=data_dir,
|
|
||||||
sample_indices=sample_indices,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
verbose=verbose,
|
|
||||||
)
|
|
||||||
task_results.append(result)
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f" -> {task_name}: {result['correct']}/{result['total']} "
|
|
||||||
f"({result['accuracy']*100:.1f}%) avg_score={result['avg_score']:.3f}")
|
|
||||||
|
|
||||||
total_time = time.time() - start_time
|
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
del llm
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# Aggregate results
|
|
||||||
total_correct = sum(r["correct"] for r in task_results)
|
|
||||||
total_samples = sum(r["total"] for r in task_results)
|
|
||||||
overall_accuracy = total_correct / total_samples if total_samples > 0 else 0.0
|
|
||||||
avg_score = sum(r["avg_score"] for r in task_results) / len(task_results) if task_results else 0.0
|
|
||||||
|
|
||||||
# Print summary
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"RULER Benchmark Results")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f"\n{'Task':<20} {'Correct':<10} {'Accuracy':<12} {'Avg Score':<12}")
|
|
||||||
print(f"{'-'*54}")
|
|
||||||
for r in task_results:
|
|
||||||
print(f"{r['task']:<20} {r['correct']}/{r['total']:<7} {r['accuracy']*100:>6.1f}% {r['avg_score']:.3f}")
|
|
||||||
print(f"{'-'*54}")
|
|
||||||
print(f"{'TOTAL':<20} {total_correct}/{total_samples:<7} {overall_accuracy*100:>6.1f}% {avg_score:.3f}")
|
|
||||||
print(f"\nTime: {total_time:.1f}s")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_correct": total_correct,
|
|
||||||
"total_samples": total_samples,
|
|
||||||
"overall_accuracy": overall_accuracy,
|
|
||||||
"avg_score": avg_score,
|
|
||||||
"time": total_time,
|
|
||||||
"task_results": task_results,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# CLI Entry Point
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="RULER benchmark comprehensive test",
|
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument("--model", "-m", type=str, default=DEFAULT_MODEL,
|
|
||||||
help=f"Path to model (default: {DEFAULT_MODEL})")
|
|
||||||
parser.add_argument("--data-dir", type=str, default=str(DEFAULT_DATA_DIR),
|
|
||||||
help=f"Path to data directory (default: {DEFAULT_DATA_DIR})")
|
|
||||||
parser.add_argument("--datasets", type=str, default="",
|
|
||||||
help="Comma-separated list of datasets to test (default: all)")
|
|
||||||
parser.add_argument("--num-samples", type=int, default=0,
|
|
||||||
help="Number of samples per dataset (default: 0 = all)")
|
|
||||||
parser.add_argument("--max-model-len", type=int, default=DEFAULT_MAX_MODEL_LEN,
|
|
||||||
help=f"Maximum model context length (default: {DEFAULT_MAX_MODEL_LEN})")
|
|
||||||
parser.add_argument("--max-new-tokens", type=int, default=DEFAULT_MAX_NEW_TOKENS,
|
|
||||||
help=f"Maximum tokens to generate (default: {DEFAULT_MAX_NEW_TOKENS})")
|
|
||||||
parser.add_argument("--enable-offload", action="store_true",
|
|
||||||
help="Enable CPU offload mode")
|
|
||||||
parser.add_argument("--num-gpu-blocks", type=int, default=4,
|
|
||||||
help="Number of GPU blocks for CPU offload (default: 4)")
|
|
||||||
parser.add_argument("--block-size", type=int, default=1024,
|
|
||||||
help="KV cache block size (default: 1024)")
|
|
||||||
parser.add_argument("--num-kv-buffers", type=int, default=4,
|
|
||||||
help="Number of KV buffers for ring buffer (default: 4)")
|
|
||||||
parser.add_argument("--gpu-utilization", type=float, default=0.9,
|
|
||||||
help="GPU memory utilization (default: 0.9)")
|
|
||||||
parser.add_argument("--use-cuda-graph", action="store_true",
|
|
||||||
help="Enable CUDA graph")
|
|
||||||
parser.add_argument("--quiet", "-q", action="store_true",
|
|
||||||
help="Quiet mode")
|
|
||||||
parser.add_argument("--sparse-policy", type=str, default="",
|
|
||||||
help="Sparse attention policy (FULL, QUEST, MINFERENCE, XATTN)")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Parse datasets
|
|
||||||
datasets = args.datasets.split(",") if args.datasets else None
|
|
||||||
num_samples = args.num_samples if args.num_samples > 0 else None
|
|
||||||
|
|
||||||
# Parse sparse policy
|
|
||||||
sparse_policy_str = args.sparse_policy.upper() if args.sparse_policy else None
|
|
||||||
|
|
||||||
results = run_ruler_benchmark(
|
|
||||||
model_path=os.path.expanduser(args.model),
|
|
||||||
data_dir=Path(args.data_dir),
|
|
||||||
datasets=datasets,
|
|
||||||
num_samples=num_samples,
|
|
||||||
max_model_len=args.max_model_len,
|
|
||||||
max_new_tokens=args.max_new_tokens,
|
|
||||||
enable_cpu_offload=args.enable_offload,
|
|
||||||
num_gpu_blocks=args.num_gpu_blocks,
|
|
||||||
block_size=args.block_size,
|
|
||||||
num_kv_buffers=args.num_kv_buffers,
|
|
||||||
gpu_utilization=args.gpu_utilization,
|
|
||||||
enforce_eager=not args.use_cuda_graph,
|
|
||||||
verbose=not args.quiet,
|
|
||||||
sparse_policy=sparse_policy_str,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Exit code
|
|
||||||
if results["overall_accuracy"] >= 0.5:
|
|
||||||
print("test_ruler: PASSED")
|
|
||||||
else:
|
|
||||||
print(f"test_ruler: FAILED (accuracy={results['overall_accuracy']*100:.1f}%)")
|
|
||||||
exit(1)
|
|
||||||
@@ -1,527 +0,0 @@
|
|||||||
"""
|
|
||||||
RULER NIAH benchmark test for LLM.
|
|
||||||
|
|
||||||
Tests: Long context retrieval capability using pre-generated RULER benchmark data.
|
|
||||||
The NIAH (Needle-In-A-Haystack) task tests the model's ability to retrieve a
|
|
||||||
specific magic number from a large context (~32K tokens).
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Test all samples with CPU offload
|
|
||||||
python tests/test_ruler_niah.py --enable-offload
|
|
||||||
|
|
||||||
# Test specific samples
|
|
||||||
python tests/test_ruler_niah.py --sample-indices 0,1,2 --enable-offload
|
|
||||||
|
|
||||||
# Test with custom model
|
|
||||||
python tests/test_ruler_niah.py --model /path/to/model --enable-offload
|
|
||||||
|
|
||||||
# Group mode: test in batches with separate LLM initialization per group
|
|
||||||
python tests/test_ruler_niah.py --enable-offload --group-size 5
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Tuple, Optional
|
|
||||||
|
|
||||||
from nanovllm import LLM, SamplingParams
|
|
||||||
from utils import check_needle_answer
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Constants
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
DEFAULT_DATA_FILE = Path(__file__).parent / "data/ruler_niah/niah_single_1_32k.jsonl"
|
|
||||||
DEFAULT_MODEL = os.path.expanduser("~/models/Llama-3.1-8B-Instruct")
|
|
||||||
DEFAULT_MAX_MODEL_LEN = 32768
|
|
||||||
DEFAULT_MAX_NEW_TOKENS = 50
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Data Loading
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def load_ruler_samples(filepath: Path, indices: Optional[List[int]] = None) -> List[dict]:
|
|
||||||
"""
|
|
||||||
Load RULER NIAH samples from a JSONL file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filepath: Path to the JSONL file
|
|
||||||
indices: Optional list of sample indices to load. If None, load all.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of sample dicts with keys: index, input, outputs, length
|
|
||||||
"""
|
|
||||||
if not filepath.exists():
|
|
||||||
raise FileNotFoundError(
|
|
||||||
f"Data file not found: {filepath}\n"
|
|
||||||
f"Please copy RULER NIAH data to this location. See docs/ruler_niah_standalone_test.md"
|
|
||||||
)
|
|
||||||
|
|
||||||
samples = []
|
|
||||||
with open(filepath) as f:
|
|
||||||
for i, line in enumerate(f):
|
|
||||||
if indices is None or i in indices:
|
|
||||||
sample = json.loads(line)
|
|
||||||
samples.append(sample)
|
|
||||||
|
|
||||||
if not samples:
|
|
||||||
raise ValueError(f"No samples loaded from {filepath}")
|
|
||||||
|
|
||||||
return samples
|
|
||||||
|
|
||||||
|
|
||||||
def count_samples(filepath: Path) -> int:
|
|
||||||
"""Count total samples in JSONL file."""
|
|
||||||
with open(filepath) as f:
|
|
||||||
return sum(1 for _ in f)
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Test Function
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def run_ruler_niah_test(
|
|
||||||
model_path: str,
|
|
||||||
data_file: Path,
|
|
||||||
sample_indices: Optional[List[int]] = None,
|
|
||||||
max_model_len: int = DEFAULT_MAX_MODEL_LEN,
|
|
||||||
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
|
|
||||||
enable_cpu_offload: bool = False,
|
|
||||||
num_gpu_blocks: int = 4,
|
|
||||||
block_size: int = 1024,
|
|
||||||
gpu_utilization: float = 0.9,
|
|
||||||
enforce_eager: bool = True,
|
|
||||||
verbose: bool = True,
|
|
||||||
) -> Tuple[int, int]:
|
|
||||||
"""
|
|
||||||
Run RULER NIAH test on loaded samples.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path: Path to the model
|
|
||||||
data_file: Path to JSONL data file
|
|
||||||
sample_indices: List of sample indices to test (None = all)
|
|
||||||
max_model_len: Maximum model context length
|
|
||||||
max_new_tokens: Maximum tokens to generate
|
|
||||||
enable_cpu_offload: Enable CPU offload mode
|
|
||||||
num_gpu_blocks: Number of GPU blocks for offload
|
|
||||||
block_size: KV cache block size
|
|
||||||
gpu_utilization: GPU memory utilization fraction
|
|
||||||
enforce_eager: Disable CUDA graphs
|
|
||||||
verbose: Print detailed output
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(correct, total): Number of correct and total samples
|
|
||||||
"""
|
|
||||||
# Load samples
|
|
||||||
samples = load_ruler_samples(data_file, sample_indices)
|
|
||||||
total = len(samples)
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"RULER NIAH Test")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f"Model: {model_path}")
|
|
||||||
print(f"Data file: {data_file}")
|
|
||||||
print(f"Samples: {total}")
|
|
||||||
print(f"Max model len: {max_model_len}")
|
|
||||||
print(f"Max new tokens: {max_new_tokens}")
|
|
||||||
print(f"CPU offload: {enable_cpu_offload}")
|
|
||||||
if enable_cpu_offload:
|
|
||||||
print(f" num_gpu_blocks: {num_gpu_blocks}")
|
|
||||||
print(f" block_size: {block_size}")
|
|
||||||
print(f"Enforce eager: {enforce_eager}")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
# Check max_model_len vs data length
|
|
||||||
max_data_len = max(s.get("length", 0) for s in samples)
|
|
||||||
if max_model_len < max_data_len:
|
|
||||||
print(f"WARNING: max_model_len ({max_model_len}) < max data length ({max_data_len})")
|
|
||||||
print(f" This may cause truncation or errors.\n")
|
|
||||||
|
|
||||||
# Initialize LLM
|
|
||||||
if verbose:
|
|
||||||
print("Initializing LLM...")
|
|
||||||
|
|
||||||
llm_kwargs = {
|
|
||||||
"max_model_len": max_model_len,
|
|
||||||
"max_num_batched_tokens": max_model_len,
|
|
||||||
"enforce_eager": enforce_eager,
|
|
||||||
"gpu_memory_utilization": gpu_utilization,
|
|
||||||
"kvcache_block_size": block_size,
|
|
||||||
"enable_cpu_offload": enable_cpu_offload,
|
|
||||||
}
|
|
||||||
|
|
||||||
if enable_cpu_offload:
|
|
||||||
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
|
|
||||||
|
|
||||||
llm = LLM(model_path, **llm_kwargs)
|
|
||||||
|
|
||||||
# Sampling params
|
|
||||||
# Note: nano-vllm doesn't support greedy (temperature=0), use low temperature instead
|
|
||||||
sampling_params = SamplingParams(
|
|
||||||
temperature=0.1, # Low temperature for near-deterministic output
|
|
||||||
max_tokens=max_new_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test each sample
|
|
||||||
correct = 0
|
|
||||||
results = []
|
|
||||||
|
|
||||||
for i, sample in enumerate(samples):
|
|
||||||
sample_idx = sample.get("index", i)
|
|
||||||
prompt = sample["input"]
|
|
||||||
expected = sample["outputs"][0]
|
|
||||||
data_len = sample.get("length", "unknown")
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
print(f"\nSample {sample_idx}: Expected={expected}, Length={data_len}")
|
|
||||||
|
|
||||||
# Generate
|
|
||||||
outputs = llm.generate([prompt], sampling_params, use_tqdm=False)
|
|
||||||
output_text = outputs[0]["text"]
|
|
||||||
output_tokens = outputs[0]["token_ids"]
|
|
||||||
|
|
||||||
# Check result
|
|
||||||
passed = check_needle_answer(output_text, expected)
|
|
||||||
if passed:
|
|
||||||
correct += 1
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
"index": sample_idx,
|
|
||||||
"expected": expected,
|
|
||||||
"output": output_text,
|
|
||||||
"passed": passed,
|
|
||||||
})
|
|
||||||
|
|
||||||
if verbose:
|
|
||||||
status = "PASS" if passed else "FAIL"
|
|
||||||
output_preview = output_text[:100].replace('\n', ' ')
|
|
||||||
print(f" Output ({len(output_tokens)} tokens): {output_preview}...")
|
|
||||||
print(f" Status: {status}")
|
|
||||||
|
|
||||||
# Summary
|
|
||||||
if verbose:
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"Results: {correct}/{total} PASSED ({100*correct/total:.1f}%)")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
if correct < total:
|
|
||||||
print("Failed samples:")
|
|
||||||
for r in results:
|
|
||||||
if not r["passed"]:
|
|
||||||
print(f" Sample {r['index']}: expected={r['expected']}, got={r['output'][:50]}...")
|
|
||||||
|
|
||||||
return correct, total
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# Grouped Test Function
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def run_grouped_test(
|
|
||||||
model_path: str,
|
|
||||||
data_file: Path,
|
|
||||||
group_size: int = 5,
|
|
||||||
total_samples: Optional[int] = None,
|
|
||||||
max_model_len: int = DEFAULT_MAX_MODEL_LEN,
|
|
||||||
max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
|
|
||||||
enable_cpu_offload: bool = False,
|
|
||||||
num_gpu_blocks: int = 4,
|
|
||||||
block_size: int = 1024,
|
|
||||||
gpu_utilization: float = 0.9,
|
|
||||||
enforce_eager: bool = True,
|
|
||||||
) -> Tuple[int, int, List[dict]]:
|
|
||||||
"""
|
|
||||||
Run RULER NIAH test in groups, with separate LLM initialization per group.
|
|
||||||
|
|
||||||
This mode is useful for:
|
|
||||||
- Avoiding state accumulation issues
|
|
||||||
- Testing LLM initialization stability
|
|
||||||
- Running large-scale tests with memory cleanup between groups
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_path: Path to the model
|
|
||||||
data_file: Path to JSONL data file
|
|
||||||
group_size: Number of samples per group
|
|
||||||
total_samples: Total samples to test (None = all in file)
|
|
||||||
Other args: Same as run_ruler_niah_test
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(total_correct, total_tested, group_results): Results summary
|
|
||||||
"""
|
|
||||||
import time
|
|
||||||
import gc
|
|
||||||
import torch
|
|
||||||
|
|
||||||
# Count total samples in file
|
|
||||||
file_sample_count = count_samples(data_file)
|
|
||||||
if total_samples is None:
|
|
||||||
total_samples = file_sample_count
|
|
||||||
else:
|
|
||||||
total_samples = min(total_samples, file_sample_count)
|
|
||||||
|
|
||||||
num_groups = (total_samples + group_size - 1) // group_size
|
|
||||||
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"RULER NIAH Grouped Test")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f"Model: {model_path}")
|
|
||||||
print(f"Data file: {data_file}")
|
|
||||||
print(f"Total samples: {total_samples}")
|
|
||||||
print(f"Group size: {group_size}")
|
|
||||||
print(f"Number of groups: {num_groups}")
|
|
||||||
print(f"CPU offload: {enable_cpu_offload}")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
total_correct = 0
|
|
||||||
total_tested = 0
|
|
||||||
group_results = []
|
|
||||||
all_failed = []
|
|
||||||
|
|
||||||
test_start_time = time.time()
|
|
||||||
|
|
||||||
for group_idx in range(num_groups):
|
|
||||||
start_idx = group_idx * group_size
|
|
||||||
end_idx = min(start_idx + group_size, total_samples)
|
|
||||||
sample_indices = list(range(start_idx, end_idx))
|
|
||||||
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"Group {group_idx + 1}/{num_groups}: Samples {start_idx}-{end_idx - 1}")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
|
|
||||||
group_start_time = time.time()
|
|
||||||
|
|
||||||
# Run test for this group
|
|
||||||
correct, tested = run_ruler_niah_test(
|
|
||||||
model_path=model_path,
|
|
||||||
data_file=data_file,
|
|
||||||
sample_indices=sample_indices,
|
|
||||||
max_model_len=max_model_len,
|
|
||||||
max_new_tokens=max_new_tokens,
|
|
||||||
enable_cpu_offload=enable_cpu_offload,
|
|
||||||
num_gpu_blocks=num_gpu_blocks,
|
|
||||||
block_size=block_size,
|
|
||||||
gpu_utilization=gpu_utilization,
|
|
||||||
enforce_eager=enforce_eager,
|
|
||||||
verbose=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
group_time = time.time() - group_start_time
|
|
||||||
|
|
||||||
total_correct += correct
|
|
||||||
total_tested += tested
|
|
||||||
|
|
||||||
group_result = {
|
|
||||||
"group": group_idx + 1,
|
|
||||||
"samples": f"{start_idx}-{end_idx - 1}",
|
|
||||||
"correct": correct,
|
|
||||||
"total": tested,
|
|
||||||
"accuracy": 100 * correct / tested if tested > 0 else 0,
|
|
||||||
"time": group_time,
|
|
||||||
}
|
|
||||||
group_results.append(group_result)
|
|
||||||
|
|
||||||
print(f"\nGroup {group_idx + 1} Summary: {correct}/{tested} PASSED ({group_result['accuracy']:.1f}%) in {group_time:.1f}s")
|
|
||||||
|
|
||||||
# Force cleanup between groups
|
|
||||||
gc.collect()
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
# Small delay to ensure port is released
|
|
||||||
if group_idx < num_groups - 1:
|
|
||||||
time.sleep(3)
|
|
||||||
|
|
||||||
total_time = time.time() - test_start_time
|
|
||||||
|
|
||||||
# Final summary
|
|
||||||
print(f"\n{'='*60}")
|
|
||||||
print(f"FINAL SUMMARY")
|
|
||||||
print(f"{'='*60}")
|
|
||||||
print(f"\nGroup Results:")
|
|
||||||
print(f"{'Group':<8} {'Samples':<12} {'Result':<12} {'Accuracy':<10} {'Time':<10}")
|
|
||||||
print(f"{'-'*52}")
|
|
||||||
for r in group_results:
|
|
||||||
print(f"{r['group']:<8} {r['samples']:<12} {r['correct']}/{r['total']:<9} {r['accuracy']:.1f}%{'':<5} {r['time']:.1f}s")
|
|
||||||
|
|
||||||
print(f"{'-'*52}")
|
|
||||||
overall_accuracy = 100 * total_correct / total_tested if total_tested > 0 else 0
|
|
||||||
print(f"{'TOTAL':<8} {'0-' + str(total_tested-1):<12} {total_correct}/{total_tested:<9} {overall_accuracy:.1f}%{'':<5} {total_time:.1f}s")
|
|
||||||
print(f"{'='*60}\n")
|
|
||||||
|
|
||||||
return total_correct, total_tested, group_results
|
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
|
||||||
# CLI Entry Point
|
|
||||||
# ============================================================
|
|
||||||
|
|
||||||
def parse_indices(s: str) -> List[int]:
|
|
||||||
"""Parse comma-separated indices like '0,1,2' or range like '0-4'."""
|
|
||||||
if not s:
|
|
||||||
return None
|
|
||||||
indices = []
|
|
||||||
for part in s.split(','):
|
|
||||||
if '-' in part:
|
|
||||||
start, end = part.split('-')
|
|
||||||
indices.extend(range(int(start), int(end) + 1))
|
|
||||||
else:
|
|
||||||
indices.append(int(part))
|
|
||||||
return indices
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="RULER NIAH benchmark test for long context LLM",
|
|
||||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
||||||
epilog="""
|
|
||||||
Examples:
|
|
||||||
# Test all samples with CPU offload (recommended for 24GB GPUs)
|
|
||||||
python tests/test_ruler_niah.py --enable-offload
|
|
||||||
|
|
||||||
# Test specific samples
|
|
||||||
python tests/test_ruler_niah.py --sample-indices 0,1,2 --enable-offload
|
|
||||||
|
|
||||||
# Test with CUDA graph enabled
|
|
||||||
python tests/test_ruler_niah.py --enable-offload --use-cuda-graph
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--model", "-m",
|
|
||||||
type=str,
|
|
||||||
default=DEFAULT_MODEL,
|
|
||||||
help=f"Path to model (default: {DEFAULT_MODEL})"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--data-file",
|
|
||||||
type=str,
|
|
||||||
default=str(DEFAULT_DATA_FILE),
|
|
||||||
help=f"Path to JSONL data file (default: {DEFAULT_DATA_FILE})"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--sample-indices",
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="Sample indices to test (e.g., '0,1,2' or '0-4'). Default: all"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-model-len",
|
|
||||||
type=int,
|
|
||||||
default=DEFAULT_MAX_MODEL_LEN,
|
|
||||||
help=f"Maximum model context length (default: {DEFAULT_MAX_MODEL_LEN})"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-new-tokens",
|
|
||||||
type=int,
|
|
||||||
default=DEFAULT_MAX_NEW_TOKENS,
|
|
||||||
help=f"Maximum tokens to generate (default: {DEFAULT_MAX_NEW_TOKENS})"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--enable-offload",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable CPU offload mode (required for 24GB GPUs with 32K context)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num-gpu-blocks",
|
|
||||||
type=int,
|
|
||||||
default=4,
|
|
||||||
help="Number of GPU blocks for CPU offload (default: 4)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--block-size",
|
|
||||||
type=int,
|
|
||||||
default=1024,
|
|
||||||
help="KV cache block size (default: 1024)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--gpu-utilization",
|
|
||||||
type=float,
|
|
||||||
default=0.9,
|
|
||||||
help="GPU memory utilization fraction (default: 0.9)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--enforce-eager",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Force eager execution, disable CUDA graphs (default: True)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--use-cuda-graph",
|
|
||||||
action="store_true",
|
|
||||||
help="Enable CUDA graph (overrides --enforce-eager)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--verbose",
|
|
||||||
action="store_true",
|
|
||||||
default=True,
|
|
||||||
help="Print detailed output (default: True)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--quiet", "-q",
|
|
||||||
action="store_true",
|
|
||||||
help="Quiet mode, only print final result"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--group-size",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="Enable grouped testing mode with specified group size. Each group initializes LLM separately. (default: 0 = disabled)"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--total-samples",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
help="Total number of samples to test in group mode (default: 0 = all samples in file)"
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
# Process arguments
|
|
||||||
sample_indices = parse_indices(args.sample_indices)
|
|
||||||
enforce_eager = not args.use_cuda_graph
|
|
||||||
verbose = not args.quiet
|
|
||||||
|
|
||||||
# Check if group mode is enabled
|
|
||||||
if args.group_size > 0:
|
|
||||||
# Grouped testing mode
|
|
||||||
total_samples = args.total_samples if args.total_samples > 0 else None
|
|
||||||
correct, total, _ = run_grouped_test(
|
|
||||||
model_path=os.path.expanduser(args.model),
|
|
||||||
data_file=Path(args.data_file),
|
|
||||||
group_size=args.group_size,
|
|
||||||
total_samples=total_samples,
|
|
||||||
max_model_len=args.max_model_len,
|
|
||||||
max_new_tokens=args.max_new_tokens,
|
|
||||||
enable_cpu_offload=args.enable_offload,
|
|
||||||
num_gpu_blocks=args.num_gpu_blocks,
|
|
||||||
block_size=args.block_size,
|
|
||||||
gpu_utilization=args.gpu_utilization,
|
|
||||||
enforce_eager=enforce_eager,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Standard testing mode
|
|
||||||
correct, total = run_ruler_niah_test(
|
|
||||||
model_path=os.path.expanduser(args.model),
|
|
||||||
data_file=Path(args.data_file),
|
|
||||||
sample_indices=sample_indices,
|
|
||||||
max_model_len=args.max_model_len,
|
|
||||||
max_new_tokens=args.max_new_tokens,
|
|
||||||
enable_cpu_offload=args.enable_offload,
|
|
||||||
num_gpu_blocks=args.num_gpu_blocks,
|
|
||||||
block_size=args.block_size,
|
|
||||||
gpu_utilization=args.gpu_utilization,
|
|
||||||
enforce_eager=enforce_eager,
|
|
||||||
verbose=verbose,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Final status
|
|
||||||
if correct == total:
|
|
||||||
print("test_ruler_niah: PASSED")
|
|
||||||
else:
|
|
||||||
print(f"test_ruler_niah: FAILED ({correct}/{total})")
|
|
||||||
exit(1)
|
|
||||||
@@ -1,242 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
#
|
|
||||||
# RULER NIAH Parallel Test Script
|
|
||||||
#
|
|
||||||
# Runs RULER NIAH benchmark across multiple GPUs in parallel.
|
|
||||||
# Each sample is tested independently (separate Python process per sample).
|
|
||||||
#
|
|
||||||
# Usage:
|
|
||||||
# ./tests/test_ruler_niah.sh [OPTIONS]
|
|
||||||
#
|
|
||||||
# Options:
|
|
||||||
# --gpus "0,1,2,3" GPUs to use (default: "0,1,2,3")
|
|
||||||
# --total N Total samples to test (default: 100)
|
|
||||||
# --model PATH Model path (default: ~/models/Llama-3.1-8B-Instruct)
|
|
||||||
# --output FILE Output log file (default: /tmp/ruler_niah_results.log)
|
|
||||||
#
|
|
||||||
|
|
||||||
# Note: Removed 'set -e' because ((var++)) returns 1 when var=0, which triggers exit
|
|
||||||
|
|
||||||
# Default configuration
|
|
||||||
GPUS="0,1,2,3"
|
|
||||||
TOTAL_SAMPLES=100
|
|
||||||
MODEL_PATH="$HOME/models/Llama-3.1-8B-Instruct"
|
|
||||||
OUTPUT_LOG="/tmp/ruler_niah_results.log"
|
|
||||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
|
||||||
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
|
|
||||||
|
|
||||||
# Parse arguments
|
|
||||||
while [[ $# -gt 0 ]]; do
|
|
||||||
case $1 in
|
|
||||||
--gpus)
|
|
||||||
GPUS="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--total)
|
|
||||||
TOTAL_SAMPLES="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--model)
|
|
||||||
MODEL_PATH="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
--output)
|
|
||||||
OUTPUT_LOG="$2"
|
|
||||||
shift 2
|
|
||||||
;;
|
|
||||||
*)
|
|
||||||
echo "Unknown option: $1"
|
|
||||||
exit 1
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
done
|
|
||||||
|
|
||||||
# Convert GPU string to array
|
|
||||||
IFS=',' read -ra GPU_ARRAY <<< "$GPUS"
|
|
||||||
NUM_GPUS=${#GPU_ARRAY[@]}
|
|
||||||
|
|
||||||
echo "============================================================"
|
|
||||||
echo "RULER NIAH Parallel Test"
|
|
||||||
echo "============================================================"
|
|
||||||
echo "GPUs: ${GPUS} (${NUM_GPUS} GPUs)"
|
|
||||||
echo "Total samples: ${TOTAL_SAMPLES}"
|
|
||||||
echo "Model: ${MODEL_PATH}"
|
|
||||||
echo "Output log: ${OUTPUT_LOG}"
|
|
||||||
echo "Project root: ${PROJECT_ROOT}"
|
|
||||||
echo "============================================================"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Create output directory
|
|
||||||
mkdir -p "$(dirname "$OUTPUT_LOG")"
|
|
||||||
|
|
||||||
# Initialize result tracking
|
|
||||||
RESULT_DIR="/tmp/ruler_niah_results_$$"
|
|
||||||
mkdir -p "$RESULT_DIR"
|
|
||||||
|
|
||||||
# Function to run a single sample on a specific GPU
|
|
||||||
run_sample() {
|
|
||||||
local gpu=$1
|
|
||||||
local sample_idx=$2
|
|
||||||
local result_file="$RESULT_DIR/sample_${sample_idx}.result"
|
|
||||||
|
|
||||||
# Run test with unique port based on GPU
|
|
||||||
local port=$((2333 + gpu))
|
|
||||||
|
|
||||||
NANOVLLM_DIST_PORT=$port \
|
|
||||||
CUDA_VISIBLE_DEVICES=$gpu \
|
|
||||||
PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH" \
|
|
||||||
python "$SCRIPT_DIR/test_ruler_niah.py" \
|
|
||||||
--model "$MODEL_PATH" \
|
|
||||||
--enable-offload \
|
|
||||||
--sample-indices "$sample_idx" \
|
|
||||||
--quiet \
|
|
||||||
2>&1
|
|
||||||
|
|
||||||
local exit_code=$?
|
|
||||||
if [ $exit_code -eq 0 ]; then
|
|
||||||
echo "PASS" > "$result_file"
|
|
||||||
else
|
|
||||||
echo "FAIL" > "$result_file"
|
|
||||||
fi
|
|
||||||
|
|
||||||
return $exit_code
|
|
||||||
}
|
|
||||||
|
|
||||||
# Function to run samples on a specific GPU
|
|
||||||
run_gpu_worker() {
|
|
||||||
local gpu=$1
|
|
||||||
local gpu_idx=$2
|
|
||||||
local log_file="$RESULT_DIR/gpu_${gpu}.log"
|
|
||||||
|
|
||||||
echo "[GPU $gpu] Starting worker (gpu_idx=$gpu_idx)" | tee -a "$log_file"
|
|
||||||
|
|
||||||
# Calculate which samples this GPU handles
|
|
||||||
local sample_idx=$gpu_idx
|
|
||||||
local pass_count=0
|
|
||||||
local fail_count=0
|
|
||||||
|
|
||||||
while [ $sample_idx -lt $TOTAL_SAMPLES ]; do
|
|
||||||
echo "[GPU $gpu] Testing sample $sample_idx..." | tee -a "$log_file"
|
|
||||||
|
|
||||||
local start_time=$(date +%s)
|
|
||||||
|
|
||||||
if run_sample $gpu $sample_idx >> "$log_file" 2>&1; then
|
|
||||||
echo "[GPU $gpu] Sample $sample_idx: PASS" | tee -a "$log_file"
|
|
||||||
((pass_count++))
|
|
||||||
else
|
|
||||||
echo "[GPU $gpu] Sample $sample_idx: FAIL" | tee -a "$log_file"
|
|
||||||
((fail_count++))
|
|
||||||
fi
|
|
||||||
|
|
||||||
local end_time=$(date +%s)
|
|
||||||
local duration=$((end_time - start_time))
|
|
||||||
echo "[GPU $gpu] Sample $sample_idx completed in ${duration}s" | tee -a "$log_file"
|
|
||||||
|
|
||||||
# Move to next sample for this GPU (stride by number of GPUs)
|
|
||||||
sample_idx=$((sample_idx + NUM_GPUS))
|
|
||||||
|
|
||||||
# Small delay to avoid port conflicts
|
|
||||||
sleep 2
|
|
||||||
done
|
|
||||||
|
|
||||||
echo "[GPU $gpu] Worker finished: $pass_count passed, $fail_count failed" | tee -a "$log_file"
|
|
||||||
echo "$pass_count $fail_count" > "$RESULT_DIR/gpu_${gpu}.summary"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Start time
|
|
||||||
START_TIME=$(date +%s)
|
|
||||||
echo "Starting parallel test at $(date '+%Y-%m-%d %H:%M:%S')"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Launch workers for each GPU in background
|
|
||||||
PIDS=()
|
|
||||||
for i in "${!GPU_ARRAY[@]}"; do
|
|
||||||
gpu=${GPU_ARRAY[$i]}
|
|
||||||
echo "Launching worker on GPU $gpu..."
|
|
||||||
run_gpu_worker $gpu $i &
|
|
||||||
PIDS+=($!)
|
|
||||||
done
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "All workers launched. Waiting for completion..."
|
|
||||||
echo "Monitor progress with: tail -f $RESULT_DIR/gpu_*.log"
|
|
||||||
echo ""
|
|
||||||
|
|
||||||
# Wait for all workers to complete
|
|
||||||
for pid in "${PIDS[@]}"; do
|
|
||||||
wait $pid
|
|
||||||
done
|
|
||||||
|
|
||||||
# End time
|
|
||||||
END_TIME=$(date +%s)
|
|
||||||
DURATION=$((END_TIME - START_TIME))
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "============================================================"
|
|
||||||
echo "FINAL RESULTS"
|
|
||||||
echo "============================================================"
|
|
||||||
|
|
||||||
# Aggregate results
|
|
||||||
TOTAL_PASS=0
|
|
||||||
TOTAL_FAIL=0
|
|
||||||
|
|
||||||
for gpu in "${GPU_ARRAY[@]}"; do
|
|
||||||
if [ -f "$RESULT_DIR/gpu_${gpu}.summary" ]; then
|
|
||||||
read pass fail < "$RESULT_DIR/gpu_${gpu}.summary"
|
|
||||||
TOTAL_PASS=$((TOTAL_PASS + pass))
|
|
||||||
TOTAL_FAIL=$((TOTAL_FAIL + fail))
|
|
||||||
echo "GPU $gpu: $pass passed, $fail failed"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
TOTAL_TESTED=$((TOTAL_PASS + TOTAL_FAIL))
|
|
||||||
if [ $TOTAL_TESTED -gt 0 ]; then
|
|
||||||
ACCURACY=$(echo "scale=1; $TOTAL_PASS * 100 / $TOTAL_TESTED" | bc)
|
|
||||||
else
|
|
||||||
ACCURACY="0.0"
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "------------------------------------------------------------"
|
|
||||||
echo "Total: $TOTAL_PASS/$TOTAL_TESTED passed ($ACCURACY%)"
|
|
||||||
echo "Duration: ${DURATION}s ($(echo "scale=1; $DURATION / 60" | bc) minutes)"
|
|
||||||
echo "Throughput: $(echo "scale=2; $TOTAL_TESTED * 60 / $DURATION" | bc) samples/min"
|
|
||||||
echo "------------------------------------------------------------"
|
|
||||||
|
|
||||||
# Save detailed results
|
|
||||||
{
|
|
||||||
echo "RULER NIAH Parallel Test Results"
|
|
||||||
echo "================================"
|
|
||||||
echo "Date: $(date '+%Y-%m-%d %H:%M:%S')"
|
|
||||||
echo "GPUs: $GPUS"
|
|
||||||
echo "Total samples: $TOTAL_TESTED"
|
|
||||||
echo "Passed: $TOTAL_PASS"
|
|
||||||
echo "Failed: $TOTAL_FAIL"
|
|
||||||
echo "Accuracy: $ACCURACY%"
|
|
||||||
echo "Duration: ${DURATION}s"
|
|
||||||
echo ""
|
|
||||||
echo "Per-sample results:"
|
|
||||||
for i in $(seq 0 $((TOTAL_SAMPLES - 1))); do
|
|
||||||
if [ -f "$RESULT_DIR/sample_${i}.result" ]; then
|
|
||||||
result=$(cat "$RESULT_DIR/sample_${i}.result")
|
|
||||||
echo "Sample $i: $result"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
} > "$OUTPUT_LOG"
|
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "Detailed results saved to: $OUTPUT_LOG"
|
|
||||||
|
|
||||||
# Cleanup
|
|
||||||
# rm -rf "$RESULT_DIR"
|
|
||||||
|
|
||||||
# Exit with appropriate code
|
|
||||||
if [ $TOTAL_FAIL -eq 0 ]; then
|
|
||||||
echo ""
|
|
||||||
echo "test_ruler_niah.sh: ALL PASSED"
|
|
||||||
exit 0
|
|
||||||
else
|
|
||||||
echo ""
|
|
||||||
echo "test_ruler_niah.sh: $TOTAL_FAIL FAILED"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
Reference in New Issue
Block a user