[refactor] Delete unnecessary tests and refactor the offload prefix cache.

Zijie Tian
2026-01-05 20:31:42 +08:00
parent 247c5312d9
commit e554d5482b
20 changed files with 258 additions and 3630 deletions

View File

@@ -75,18 +75,12 @@ for hook in hooks:
hook.remove()
```
### Alignment Testing
Use `tests/test_align.py` to compare nanovllm with the reference torch implementation:
```bash
python tests/test_align.py
```
### Reference Implementation
Key files:
- `tests/modeling_qwen3.py`: Reference Qwen3 implementation (torch + transformers only)
- `tests/test_align.py`: Compares attention outputs between nanovllm and reference
- `tests/test_needle_ref.py`: Reference needle test using custom Qwen3
- `tests/test_needle.py`: Needle-in-haystack test for nanovllm
### Common Pitfalls
@@ -179,7 +173,6 @@ memcpy_2d_async(gpu_buf, cpu_cache[:, block_id], dpitch, spitch, width, height,
**Files**:
- `csrc/sgdma_kernel.cu`, `csrc/sgdma.cpp`: CUDA extension
- `nanovllm/comm/sgdma.py`: Python API
- `tests/test_sgdma.py`: Standalone benchmark
- `kvcache/offload_engine.py`: Integration (4 methods updated)
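The 2D-copy parameters map directly onto the strided CPU cache layout `[num_layers, num_blocks, block_size, kv_heads, head_dim]`. A minimal sketch of that mapping, assuming the `memcpy_2d_async` call shown above is importable from the Python API module (the helper function and its arguments are illustrative, not the actual integration code):
```python
import torch
from nanovllm.comm.sgdma import memcpy_2d_async  # signature assumed from the call above

def load_block_all_layers(gpu_buf: torch.Tensor, cpu_cache: torch.Tensor, block_id: int):
    """Gather one block across all layers with a single strided H2D copy (sketch)."""
    num_layers, num_blocks = cpu_cache.shape[0], cpu_cache.shape[1]
    bytes_per_block = cpu_cache[0, 0].numel() * cpu_cache.element_size()
    spitch = num_blocks * bytes_per_block  # source stride: one CPU "row" per layer
    dpitch = bytes_per_block               # destination is contiguous per layer
    width = bytes_per_block                # copy exactly one block per row
    height = num_layers                    # one row per layer
    # Trailing arguments (direction/stream) are elided in the note above ("...").
    memcpy_2d_async(gpu_buf, cpu_cache[:, block_id], dpitch, spitch, width, height)
```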
### Integration Details
@@ -284,25 +277,53 @@ def _merge_output_kernel(...):
- Total GPU time: ~1,343 ms
- **Overall speedup with Triton**: 1.67x
### Correctness Verification
**Test**: `tests/test_chunked_attention.py`
- 12 test cases (6 configs × 2 dtypes)
- All tests PASS with max error < 0.01
- float16: max_diff=0.000488, mean_diff~0.00001
- bfloat16: max_diff=0.003906, mean_diff~0.0001
### Key Files
- `nanovllm/kvcache/chunked_attention.py`: Triton kernels + merge function
- `tests/test_chunked_attention.py`: Correctness tests
- `tests/test_attention_offload.py`: Performance profiling
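The merge step itself is small. A hedged PyTorch sketch of the math the Triton `_merge_output_kernel` implements (log-sum-exp combination of two partial attention results), assuming the `flash_attn_with_lse` layout of `o: [batch, seq, heads, dim]` and `lse: [batch, heads, seq]`:
```python
import torch

def merge_partial_attention(o1, lse1, o2, lse2):
    # o*: [batch, seq, heads, dim]; lse*: [batch, heads, seq] (assumed layout).
    # Keep the LSE math in float32 for numerical stability.
    lse = torch.logaddexp(lse1.float(), lse2.float())
    w1 = torch.exp(lse1.float() - lse).transpose(1, 2).unsqueeze(-1)  # [batch, seq, heads, 1]
    w2 = torch.exp(lse2.float() - lse).transpose(1, 2).unsqueeze(-1)
    o = o1.float() * w1 + o2.float() * w2
    return o.to(o1.dtype), lse
```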
## Known Issues and Fixes
### Partial Last Block Bug (FIXED ✓)
**Problem**: When the prefill token count is not an exact multiple of `block_size`, decode outputs garbage.
**Root Cause**: `_chunked_decode_attention` calculated `last_block_valid_tokens` using `len(seq) - 1`, which increases during decode. But CPU blocks are fixed after prefill!
```python
# BUG: len(seq) increases each decode step
total_prefill_tokens = len(seq) - 1 # Wrong!
last_block_valid_tokens = total_prefill_tokens % block_size # Reads garbage from CPU
```
**Fix**: Cache the original prefill length and read it back via `HybridKVCacheManager.get_prefill_len()`:
```python
# CORRECT: Use cached prefill length
total_prefill_tokens = kvcache_manager.get_prefill_len(seq) # Fixed value
```
**Files Modified**:
- `nanovllm/kvcache/hybrid_manager.py`: Added `_prefill_len` dict and `get_prefill_len()` method
- `nanovllm/layers/attention.py`: Use `get_prefill_len()` instead of `len(seq) - 1`
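A minimal sketch of the caching pattern (the `_prefill_len` / `get_prefill_len()` names come from the list above; the standalone class and the `mark_prefill_done` call site are hypothetical simplifications):
```python
class PrefillLenCache:
    """Freeze the prefill length per sequence; len(seq) keeps growing during decode."""

    def __init__(self):
        self._prefill_len: dict[int, int] = {}  # seq_id -> token count at end of prefill

    def mark_prefill_done(self, seq_id: int, num_tokens: int) -> None:
        # Called once, after the last prefill chunk has been written to CPU.
        self._prefill_len[seq_id] = num_tokens

    def get_prefill_len(self, seq_id: int) -> int:
        # Fixed value used by chunked decode attention instead of len(seq) - 1.
        return self._prefill_len[seq_id]
```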
### Block Size 4096 Race Condition (PENDING)
**Problem**: `block_size=4096` with multiple chunks produces garbled output. `block_size=1024` works correctly.
**Symptoms**:
- `CUDA_LAUNCH_BLOCKING=1` makes tests pass (confirms race condition)
- `torch.cuda.synchronize()` before `store_kvcache` fixes it (heavy-handed)
- Issue specific to larger block sizes with multiple chunks
**Current Workaround**: Default `block_size` changed from 4096 to 1024.
**Root Cause**: Suspected race between `compute_stream`, `transfer_stream_main`, and per-slot streams during layer-by-layer offload. Investigation ongoing.
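For debugging, the heavy-handed workaround above can be expressed as a wrapper (hedged sketch; `store_kvcache` stands for the KV-cache write mentioned in the symptoms, and the wrapper is purely diagnostic, not a fix):
```python
import torch

def store_kvcache_synced(store_kvcache, *args, **kwargs):
    """Drain all streams (compute, transfer_stream_main, per-slot) before the write.

    If wrapping the KV-cache write like this makes the garbled output disappear,
    a missing cross-stream wait/event is the likely culprit.
    """
    torch.cuda.synchronize()
    return store_kvcache(*args, **kwargs)
```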
## Configuration
| Parameter | Default | Notes |
|-----------|---------|-------|
| `kvcache_block_size` | 1024 | Tokens per block |
| `kvcache_block_size` | 1024 | Tokens per block (changed from 4096 due to race condition) |
| `max_num_batched_tokens` | 16384 | Set = max_model_len for long context |
| `gpu_memory_utilization` | 0.9 | GPU memory fraction |
| `enable_cpu_offload` | False | Enable for long context |
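Putting the table together, a long-context run with offload enabled looks roughly like this (hedged example; the keyword arguments match the table and the test scripts touched by this commit, while the model path and prompt are placeholders):
```python
import os
from random import randint, seed

from nanovllm import LLM, SamplingParams

llm = LLM(
    os.path.expanduser("~/models/Qwen3-0.6B/"),  # placeholder model path
    enforce_eager=True,
    max_model_len=16384,
    max_num_batched_tokens=16384,       # set equal to max_model_len for long context
    gpu_memory_utilization=0.9,
    enable_cpu_offload=True,            # off by default
    kvcache_block_size=1024,            # default changed from 4096 (see race condition above)
)

seed(0)
prompt_token_ids = [[randint(0, 10000) for _ in range(8192)]]  # synthetic long prompt
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=32))
```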

View File

@@ -62,6 +62,8 @@ class LLMEngine:
token_ids = self.model_runner.call("run", seqs, is_prefill)
self.scheduler.postprocess(seqs, token_ids)
outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
# Calculate number of tokens processed
num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
return outputs, num_tokens
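The sign encodes the step type: a prefill step returns the total token count of the batch, a decode step returns the negative sequence count. A hedged sketch of how a caller might consume it, given an `LLM` instance whose `step()` forwards this return value as in the alignment test removed by this commit (the counting loop itself is illustrative):
```python
prefill_tokens = 0
decode_tokens = 0
while not llm.is_finished():
    outputs, num_tokens = llm.step()
    if num_tokens >= 0:
        prefill_tokens += num_tokens     # prefill: total tokens across the batch
    else:
        decode_tokens += -num_tokens     # decode: one new token per sequence
```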

View File

@@ -128,6 +128,9 @@ class HybridKVCacheManager(KVCacheManager):
self.cpu_block_to_logical: Dict[int, int] = {} # cpu_block -> logical_id
# Prefix cache (uses logical block IDs)
# NOTE: Currently WRITE-ONLY in offload mode - hashes are stored but never
# used for cache hit detection. This is intentional: offload mode always
# allocates new blocks and doesn't reuse existing ones.
self.hash_to_logical_id: Dict[int, int] = {}
# Step counter for policy
@@ -258,14 +261,10 @@ class HybridKVCacheManager(KVCacheManager):
pos_in_block = seq_len % self._block_size
if pos_in_block == 1:
# Need new block
assert last_block.hash != -1
# Need new block (previous block is full)
logical_id = self.free_logical_ids.popleft()
block = self.logical_blocks[logical_id]
block.ref_count = 1
block.hash = -1
block.token_ids = []
# Allocate new block to CPU (ring buffer mode)
if not self.free_cpu_blocks:
@@ -279,17 +278,13 @@ class HybridKVCacheManager(KVCacheManager):
block_table.append(logical_id)
elif pos_in_block == 0:
# Block is full, update hash for prefix cache
assert last_block.hash == -1
token_ids = seq.block(seq.num_blocks - 1)
prefix_hash = (
self.logical_blocks[block_table[-2]].hash
if len(block_table) > 1 else -1
)
h = self.compute_hash(token_ids, prefix_hash)
last_block.hash = h
last_block.token_ids = token_ids.copy()
self.hash_to_logical_id[h] = last_logical_id
# Block is full
# NOTE: Prefix cache disabled in offload mode
# If enabled, would compute hash and update:
# h = self.compute_hash(seq.block(seq.num_blocks - 1), prefix_hash)
# last_block.hash = h
# self.hash_to_logical_id[h] = last_logical_id
pass
def prepare_for_attention(
self,
@@ -369,8 +364,6 @@ class HybridKVCacheManager(KVCacheManager):
"""
assert not seq.block_table, "Sequence already has blocks"
h = -1 # Running hash for prefix cache
for i in range(seq.num_blocks):
# Allocate CPU block
if not self.free_cpu_blocks:
@@ -381,19 +374,10 @@ class HybridKVCacheManager(KVCacheManager):
cpu_block_id = self.free_cpu_blocks.popleft()
# Get token IDs for this block and compute hash
token_ids = seq.block(i)
if len(token_ids) == self._block_size:
h = self.compute_hash(token_ids, h)
else:
h = -1 # Incomplete block
# Allocate logical block
logical_id = self.free_logical_ids.popleft()
block = self.logical_blocks[logical_id]
block.ref_count = 1
block.hash = h
block.token_ids = token_ids.copy() if len(token_ids) == self._block_size else []
block.location = BlockLocation.CPU
block.cpu_block_id = cpu_block_id
block.gpu_slot = -1
@@ -401,9 +385,11 @@ class HybridKVCacheManager(KVCacheManager):
self.cpu_block_to_logical[cpu_block_id] = logical_id
seq.block_table.append(logical_id)
# Update prefix cache
if h != -1:
self.hash_to_logical_id[h] = logical_id
# NOTE: Prefix cache disabled in offload mode
# If enabled, would compute hash and update:
# h = self.compute_hash(seq.block(i), prefix_hash)
# block.hash = h
# self.hash_to_logical_id[h] = logical_id
def get_cpu_block_table(self, seq: Sequence) -> List[int]:
"""

View File

@@ -1,23 +0,0 @@
cmake_minimum_required(VERSION 3.18)
project(sgdma_test CUDA CXX)
# Find CUDA
enable_language(CUDA)
find_package(CUDA REQUIRED)
# Set C++ standard
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
# CUDA flags
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 --use_fast_math")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
# Build test executable
add_executable(sgdma_test sgdma_test.cpp)
target_link_libraries(sgdma_test cudart)
# Set output directory
set_target_properties(sgdma_test PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
)

View File

@@ -1,326 +0,0 @@
#include <cuda_runtime.h>
#include <iostream>
#include <chrono>
#include <cstring>
#include <cstdlib>
#include <iomanip>
// CUDA error checking macro
#define CUDA_CHECK(call) do { \
cudaError_t err = call; \
if (err != cudaSuccess) { \
std::cerr << "CUDA Error in " << __FILE__ << " at line " << __LINE__ << ": " \
<< cudaGetErrorString(err) << std::endl; \
exit(EXIT_FAILURE); \
} \
} while (0)
// Configuration matching nano-vllm realistic parameters
struct Config {
int num_layers = 32;
int num_blocks = 10; // Reduced from 100 to avoid huge allocation
int block_size = 4096;
int num_kv_heads = 8;
int head_dim = 128;
int dtype_size = 2; // float16
// Derived parameters (use size_t to avoid overflow)
size_t features_per_block() const { return (size_t)block_size * num_kv_heads * head_dim; }
size_t bytes_per_block() const { return features_per_block() * dtype_size; }
int total_blocks_per_layer() const { return num_blocks; }
size_t bytes_per_layer() const { return (size_t)num_blocks * bytes_per_block(); }
size_t total_bytes() const { return (size_t)num_layers * bytes_per_layer(); }
};
// Timer utility
class Timer {
std::chrono::high_resolution_clock::time_point start_time;
public:
void start() { start_time = std::chrono::high_resolution_clock::now(); }
double elapsed_ms() {
auto end = std::chrono::high_resolution_clock::now();
return std::chrono::duration<double, std::milli>(end - start_time).count();
}
};
// Initialize CPU memory with test pattern
void init_test_data(void* data, size_t bytes, int seed) {
uint16_t* ptr = static_cast<uint16_t*>(data);
size_t num_elements = bytes / sizeof(uint16_t);
for (size_t i = 0; i < num_elements; i++) {
ptr[i] = static_cast<uint16_t>((seed + i) % 65536);
}
}
// Verify data correctness
bool verify_data(const void* data1, const void* data2, size_t bytes) {
const uint16_t* p1 = static_cast<const uint16_t*>(data1);
const uint16_t* p2 = static_cast<const uint16_t*>(data2);
size_t num_elements = bytes / sizeof(uint16_t);
for (size_t i = 0; i < num_elements; i++) {
if (p1[i] != p2[i]) {
std::cerr << "Mismatch at element " << i << ": "
<< p1[i] << " != " << p2[i] << std::endl;
return false;
}
}
return true;
}
// ============================================================
// Test 1: Basic Functionality Test
// ============================================================
bool test_basic_functionality(const Config& cfg) {
std::cout << "\n[Test 1] Basic Functionality Test" << std::endl;
std::cout << " Testing cudaMemcpy2D correctness with strided layout" << std::endl;
// Allocate strided CPU memory (pinned)
// Layout: [num_layers, num_blocks, block_features]
size_t total_bytes = cfg.total_bytes();
std::cout << " Allocating " << total_bytes / 1024.0 / 1024.0 / 1024.0 << " GB pinned memory..." << std::endl;
void* cpu_strided = nullptr;
CUDA_CHECK(cudaMallocHost(&cpu_strided, total_bytes));
std::cout << " CPU strided memory allocated at: " << cpu_strided << std::endl;
// Allocate GPU memory for one block (all layers)
size_t gpu_block_bytes = cfg.num_layers * cfg.bytes_per_block();
void* gpu_data = nullptr;
CUDA_CHECK(cudaMalloc(&gpu_data, gpu_block_bytes));
// Allocate CPU verify buffer
void* cpu_verify = nullptr;
CUDA_CHECK(cudaMallocHost(&cpu_verify, gpu_block_bytes));
// Initialize strided CPU memory
init_test_data(cpu_strided, total_bytes, 12345);
// Test: Copy block_id=5 from CPU to GPU using cudaMemcpy2D
int test_block_id = 5;
size_t spitch = cfg.bytes_per_layer(); // Source pitch (stride between layers)
size_t dpitch = cfg.bytes_per_block(); // Destination pitch (contiguous)
size_t width = cfg.bytes_per_block(); // Width to copy per row
size_t height = cfg.num_layers; // Number of rows (layers)
// Debug: print parameters
std::cout << " cudaMemcpy2D parameters:" << std::endl;
std::cout << " spitch: " << spitch << " bytes" << std::endl;
std::cout << " dpitch: " << dpitch << " bytes" << std::endl;
std::cout << " width: " << width << " bytes" << std::endl;
std::cout << " height: " << height << " rows" << std::endl;
std::cout << " dpitch >= width: " << (dpitch >= width ? "yes" : "no") << std::endl;
std::cout << " spitch >= width: " << (spitch >= width ? "yes" : "no") << std::endl;
// Calculate source pointer (first layer, block_id)
uint8_t* src_ptr = static_cast<uint8_t*>(cpu_strided) + test_block_id * cfg.bytes_per_block();
// H2D transfer
CUDA_CHECK(cudaMemcpy2D(
gpu_data, // dst
dpitch, // dpitch
src_ptr, // src
spitch, // spitch
width, // width
height, // height
cudaMemcpyHostToDevice
));
// D2H transfer back
CUDA_CHECK(cudaMemcpy2D(
cpu_verify, // dst
dpitch, // dpitch
gpu_data, // src
dpitch, // spitch
width, // width
height, // height
cudaMemcpyDeviceToHost
));
// Verify correctness
bool passed = true;
for (int layer = 0; layer < cfg.num_layers; layer++) {
uint8_t* expected_ptr = static_cast<uint8_t*>(cpu_strided) +
layer * cfg.bytes_per_layer() +
test_block_id * cfg.bytes_per_block();
uint8_t* actual_ptr = static_cast<uint8_t*>(cpu_verify) +
layer * cfg.bytes_per_block();
if (!verify_data(expected_ptr, actual_ptr, cfg.bytes_per_block())) {
std::cerr << " Verification failed at layer " << layer << std::endl;
passed = false;
break;
}
}
// Cleanup
CUDA_CHECK(cudaFreeHost(cpu_strided));
CUDA_CHECK(cudaFreeHost(cpu_verify));
CUDA_CHECK(cudaFree(gpu_data));
std::cout << " Result: " << (passed ? "PASSED ✓" : "FAILED ✗") << std::endl;
return passed;
}
// ============================================================
// Test 2: Performance Benchmark
// ============================================================
void test_performance_benchmark(const Config& cfg) {
std::cout << "\n[Test 2] Performance Benchmark" << std::endl;
std::cout << " Configuration:" << std::endl;
std::cout << " num_layers: " << cfg.num_layers << std::endl;
std::cout << " num_blocks: " << cfg.num_blocks << std::endl;
std::cout << " block_size: " << cfg.block_size << std::endl;
std::cout << " num_kv_heads: " << cfg.num_kv_heads << std::endl;
std::cout << " head_dim: " << cfg.head_dim << std::endl;
std::cout << " dtype_size: " << cfg.dtype_size << " bytes" << std::endl;
std::cout << " bytes_per_block: " << cfg.bytes_per_block() / 1024.0 << " KB" << std::endl;
std::cout << " total transfer size: " << cfg.num_layers * cfg.bytes_per_block() / 1024.0 / 1024.0 << " MB" << std::endl;
const int num_iterations = 100;
const int warmup = 10;
int test_block_id = 5;
// Allocate memory
size_t total_bytes = cfg.total_bytes();
void* cpu_strided = nullptr;
CUDA_CHECK(cudaMallocHost(&cpu_strided, total_bytes));
void* cpu_contiguous = nullptr;
size_t gpu_block_bytes = cfg.num_layers * cfg.bytes_per_block();
CUDA_CHECK(cudaMallocHost(&cpu_contiguous, gpu_block_bytes));
void* gpu_data = nullptr;
CUDA_CHECK(cudaMalloc(&gpu_data, gpu_block_bytes));
init_test_data(cpu_strided, total_bytes, 12345);
init_test_data(cpu_contiguous, gpu_block_bytes, 12345);
Timer timer;
double elapsed;
double bandwidth;
// ========================================
// Method A: cudaMemcpy2D with strided layout
// ========================================
size_t spitch = cfg.bytes_per_layer();
size_t dpitch = cfg.bytes_per_block();
size_t width = cfg.bytes_per_block();
size_t height = cfg.num_layers;
uint8_t* src_ptr = static_cast<uint8_t*>(cpu_strided) + test_block_id * cfg.bytes_per_block();
// Warmup
for (int i = 0; i < warmup; i++) {
CUDA_CHECK(cudaMemcpy2D(gpu_data, dpitch, src_ptr, spitch, width, height, cudaMemcpyHostToDevice));
}
CUDA_CHECK(cudaDeviceSynchronize());
// Benchmark
timer.start();
for (int i = 0; i < num_iterations; i++) {
CUDA_CHECK(cudaMemcpy2D(gpu_data, dpitch, src_ptr, spitch, width, height, cudaMemcpyHostToDevice));
}
CUDA_CHECK(cudaDeviceSynchronize());
elapsed = timer.elapsed_ms();
bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0);
std::cout << "\n Method A (cudaMemcpy2D strided):" << std::endl;
std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl;
std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl;
double method_a_bw = bandwidth;
// ========================================
// Method B: cudaMemcpy with contiguous layout (baseline)
// ========================================
// Warmup
for (int i = 0; i < warmup; i++) {
CUDA_CHECK(cudaMemcpy(gpu_data, cpu_contiguous, gpu_block_bytes, cudaMemcpyHostToDevice));
}
CUDA_CHECK(cudaDeviceSynchronize());
// Benchmark
timer.start();
for (int i = 0; i < num_iterations; i++) {
CUDA_CHECK(cudaMemcpy(gpu_data, cpu_contiguous, gpu_block_bytes, cudaMemcpyHostToDevice));
}
CUDA_CHECK(cudaDeviceSynchronize());
elapsed = timer.elapsed_ms();
bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0);
std::cout << "\n Method B (cudaMemcpy contiguous):" << std::endl;
std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl;
std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl;
double method_b_bw = bandwidth;
// ========================================
// Method C: Layer-by-layer copy (simulate PyTorch non-contiguous)
// ========================================
// Warmup
for (int i = 0; i < warmup; i++) {
for (int layer = 0; layer < cfg.num_layers; layer++) {
uint8_t* src_layer = static_cast<uint8_t*>(cpu_strided) +
layer * cfg.bytes_per_layer() +
test_block_id * cfg.bytes_per_block();
uint8_t* dst_layer = static_cast<uint8_t*>(gpu_data) + layer * cfg.bytes_per_block();
CUDA_CHECK(cudaMemcpy(dst_layer, src_layer, cfg.bytes_per_block(), cudaMemcpyHostToDevice));
}
}
CUDA_CHECK(cudaDeviceSynchronize());
// Benchmark
timer.start();
for (int i = 0; i < num_iterations; i++) {
for (int layer = 0; layer < cfg.num_layers; layer++) {
uint8_t* src_layer = static_cast<uint8_t*>(cpu_strided) +
layer * cfg.bytes_per_layer() +
test_block_id * cfg.bytes_per_block();
uint8_t* dst_layer = static_cast<uint8_t*>(gpu_data) + layer * cfg.bytes_per_block();
CUDA_CHECK(cudaMemcpy(dst_layer, src_layer, cfg.bytes_per_block(), cudaMemcpyHostToDevice));
}
}
CUDA_CHECK(cudaDeviceSynchronize());
elapsed = timer.elapsed_ms();
bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0);
std::cout << "\n Method C (layer-by-layer copy):" << std::endl;
std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl;
std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl;
double method_c_bw = bandwidth;
// Summary
std::cout << "\n ========================================" << std::endl;
std::cout << " Performance Summary:" << std::endl;
std::cout << " Method A vs Method B: " << std::setprecision(2) << (method_a_bw / method_b_bw * 100) << "%" << std::endl;
std::cout << " Method A vs Method C: " << std::setprecision(2) << (method_a_bw / method_c_bw) << "x speedup" << std::endl;
std::cout << " ========================================" << std::endl;
// Cleanup
CUDA_CHECK(cudaFreeHost(cpu_strided));
CUDA_CHECK(cudaFreeHost(cpu_contiguous));
CUDA_CHECK(cudaFree(gpu_data));
}
int main() {
std::cout << "=== cudaMemcpy2D Test ===" << std::endl;
// Print CUDA device info
int device;
CUDA_CHECK(cudaGetDevice(&device));
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
std::cout << "Using GPU: " << prop.name << std::endl;
std::cout << "Memory Clock Rate: " << prop.memoryClockRate / 1000 << " MHz" << std::endl;
std::cout << "Memory Bus Width: " << prop.memoryBusWidth << " bits" << std::endl;
std::cout << "Peak Memory Bandwidth: " <<
2.0 * prop.memoryClockRate * (prop.memoryBusWidth / 8) / 1.0e6 << " GB/s" << std::endl;
Config cfg;
// Run tests
bool test1_passed = test_basic_functionality(cfg);
test_performance_benchmark(cfg);
std::cout << "\n=== Test Complete ===" << std::endl;
std::cout << "All tests " << (test1_passed ? "PASSED ✓" : "FAILED ✗") << std::endl;
return test1_passed ? 0 : 1;
}

View File

@@ -1,365 +0,0 @@
"""
Test alignment between nanovllm and custom torch Qwen3 implementation.
Compares attention layer outputs and QKV tensors to verify correctness.
Usage:
python test_align.py # Without CPU offload
python test_align.py --enable-offload # With CPU offload
python test_align.py --input-len 4096 # Custom input length
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import argparse
import torch
from transformers import AutoTokenizer
from nanovllm import LLM, SamplingParams
from modeling_qwen3 import Qwen3ForCausalLM
from utils import generate_needle_prompt
# Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument("--enable-offload", action="store_true", help="Enable CPU offload")
parser.add_argument("--input-len", type=int, default=1024 * 12, help="Input sequence length")
parser.add_argument("--model-path", type=str, default="~/models/Qwen3-0.6B/", help="Model path")
parser.add_argument("--num-gpu-blocks", type=int, default=6, help="Number of GPU blocks (ring buffer slots)")
parser.add_argument("--block-size", type=int, default=1024, help="KV cache block size")
args = parser.parse_args()
# Config
MODEL_PATH = os.path.expanduser(args.model_path)
INPUT_LEN = args.input_len
ENABLE_OFFLOAD = args.enable_offload
NUM_GPU_BLOCKS = args.num_gpu_blocks
BLOCK_SIZE = args.block_size
DTYPE = torch.float16
print(f"Config: input_len={INPUT_LEN}, enable_offload={ENABLE_OFFLOAD}, num_gpu_blocks={NUM_GPU_BLOCKS}, block_size={BLOCK_SIZE}")
# Storage for captured tensors
nanovllm_outputs = {}
torch_outputs = {}
nanovllm_qkv = {}
nanovllm_proj_inputs = {}
torch_proj_inputs = {}
# ============================================================
# Hook functions for non-offload mode (overwrite)
# ============================================================
def make_nanovllm_hook(layer_id: int, storage: dict):
def hook(module, inputs, output):
attn_output = output[0] if isinstance(output, tuple) else output
if attn_output.dim() == 2:
attn_output = attn_output.unsqueeze(0)
storage[layer_id] = attn_output.detach().clone()
return hook
def make_nanovllm_qkv_hook(layer_id: int, storage: dict):
def hook(module, inputs):
q, k, v = inputs[0], inputs[1], inputs[2]
storage[layer_id] = {
"q": q.detach().clone(),
"k": k.detach().clone(),
"v": v.detach().clone(),
}
return hook
def make_proj_input_hook(layer_id: int, storage: dict):
def hook(module, inputs):
hidden = inputs[0]
if hidden.dim() == 2:
hidden = hidden.unsqueeze(0)
storage[layer_id] = hidden.detach().clone()
return hook
# ============================================================
# Hook functions for offload mode (accumulate Q and I, overwrite O)
# ============================================================
def make_accumulating_q_hook(layer_id: int, storage: dict):
"""Accumulate Q from each chunk for offload mode."""
def hook(module, inputs):
q = inputs[0].detach().clone()
if layer_id not in storage:
storage[layer_id] = []
storage[layer_id].append(q)
return hook
def make_accumulating_input_hook(layer_id: int, storage: dict):
"""Accumulate input hidden states from each chunk for offload mode."""
def hook(module, inputs):
hidden = inputs[0].detach().clone()
if layer_id not in storage:
storage[layer_id] = []
storage[layer_id].append(hidden)
return hook
def make_overwrite_output_hook(layer_id: int, storage: dict):
"""Overwrite output (keep only last chunk) for offload mode."""
def hook(module, inputs, output):
attn_output = output[0] if isinstance(output, tuple) else output
if attn_output.dim() == 2:
attn_output = attn_output.unsqueeze(0)
storage[layer_id] = attn_output.detach().clone()
return hook
# ============================================================
# CPU KV cache access for offload mode
# ============================================================
def get_nanovllm_kv_from_cpu(llm, seq, num_layers):
"""Get complete K, V cache from CPU side after all chunks finish."""
offload_engine = llm.model_runner.kvcache_manager.offload_engine
kvcache_manager = llm.model_runner.kvcache_manager
# CRITICAL: Synchronize all CUDA operations before reading CPU memory
# The D2H copy runs on transfer_stream_main and may still be in progress
torch.cuda.synchronize()
cpu_block_ids = kvcache_manager.get_cpu_block_table(seq)
kv_per_layer = {}
for layer_id in range(num_layers):
k_blocks = []
v_blocks = []
for cpu_block_id in cpu_block_ids:
k_block, v_block = offload_engine.get_cpu_block(layer_id, cpu_block_id)
k_blocks.append(k_block)
v_blocks.append(v_block)
# Concatenate all blocks: [total_tokens, kv_heads, head_dim]
k_full = torch.cat(k_blocks, dim=0)[:seq.num_tokens]
v_full = torch.cat(v_blocks, dim=0)[:seq.num_tokens]
kv_per_layer[layer_id] = {"k": k_full, "v": v_full}
return kv_per_layer
def make_torch_hook(layer_id: int, storage: dict):
def hook(module, inputs, output):
storage[layer_id] = output[0].detach().clone()
return hook
def cosine_sim(t1: torch.Tensor, t2: torch.Tensor) -> float:
"""Cosine similarity between flattened tensors (1.0 = identical)."""
return torch.nn.functional.cosine_similarity(
t1.flatten().float(), t2.flatten().float(), dim=0
).item()
def compute_qkv_sims(nano_qkv: dict, torch_qkv: dict, num_kv_groups: int):
"""Compute Q, K, V cosine similarities. Returns (q_sim, k_sim, v_sim)."""
nano_q = nano_qkv["q"]
torch_q = torch_qkv["q"].squeeze(0).transpose(0, 1)
nano_k = nano_qkv["k"]
torch_k = torch_qkv["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
nano_v = nano_qkv["v"]
torch_v = torch_qkv["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
return cosine_sim(nano_q, torch_q), cosine_sim(nano_k, torch_k), cosine_sim(nano_v, torch_v)
# ============================================================
# Load models
# ============================================================
print("Loading nanovllm model...")
llm_kwargs = dict(
enforce_eager=True,
max_model_len=32768,
gpu_memory_utilization=0.2,
max_num_batched_tokens=32768,
enable_cpu_offload=ENABLE_OFFLOAD,
dtype="float16",
kvcache_block_size=BLOCK_SIZE,
)
if ENABLE_OFFLOAD:
llm_kwargs["num_gpu_blocks"] = NUM_GPU_BLOCKS
llm = LLM(MODEL_PATH, **llm_kwargs)
num_heads = llm.model_runner.model.model.layers[0].self_attn.num_heads
num_kv_heads = llm.model_runner.model.model.layers[0].self_attn.num_kv_heads
num_kv_groups = num_heads // num_kv_heads
num_layers = len(llm.model_runner.model.model.layers)
print("Loading torch model...")
torch_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
torch_model = torch_model.to("cuda")
torch_model.eval()
# ============================================================
# Generate test input
# ============================================================
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
prompt, _ = generate_needle_prompt(tokenizer=tokenizer, target_length=INPUT_LEN, verbose=True)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
print(f"Input shape: {input_ids.shape}")
# ============================================================
# Register hooks
# ============================================================
nanovllm_hooks = []
nanovllm_q_accum = {} # For offload mode: accumulated Q from all chunks
nanovllm_i_accum = {} # For offload mode: accumulated I from all chunks
for layer_idx, layer in enumerate(llm.model_runner.model.model.layers):
if ENABLE_OFFLOAD:
# Offload mode: accumulate Q and I, overwrite O
nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_overwrite_output_hook(layer_idx, nanovllm_outputs)))
nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_accumulating_q_hook(layer_idx, nanovllm_q_accum)))
nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_accumulating_input_hook(layer_idx, nanovllm_i_accum)))
else:
# Non-offload mode: overwrite all
nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_nanovllm_hook(layer_idx, nanovllm_outputs)))
nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_nanovllm_qkv_hook(layer_idx, nanovllm_qkv)))
nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, nanovllm_proj_inputs)))
torch_hooks = []
for layer_idx, layer in enumerate(torch_model.model.layers):
torch_hooks.append(layer.self_attn.register_forward_hook(make_torch_hook(layer_idx, torch_outputs)))
torch_hooks.append(layer.self_attn.q_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, torch_proj_inputs)))
# ============================================================
# Run inference
# ============================================================
print("Running nanovllm inference...")
if ENABLE_OFFLOAD:
# Manual execution to capture KV cache before deallocation
# Use max_tokens=2 so sequence doesn't finish immediately after prefill
llm.add_request(input_ids[0].tolist(), SamplingParams(temperature=0.01, max_tokens=2))
# Run prefill step (this calls run_chunked_offload_prefill internally)
output, num_tokens = llm.step()
print(f"[Offload] Prefill done: {num_tokens} tokens")
# Now seq is in running queue, KV cache is in CPU
seq = llm.scheduler.running[0]
print(f"[Offload] Sequence: {seq}")
# Get KV cache from CPU BEFORE decode step deallocates it
nanovllm_kv_cpu = get_nanovllm_kv_from_cpu(llm, seq, num_layers)
print(f"[Offload] Retrieved KV cache from CPU for {seq.num_tokens} tokens")
# IMPORTANT: Save outputs NOW before decode step overwrites them
# nanovllm_outputs contains prefill attention outputs at this point
nanovllm_outputs_prefill = {k: v.clone() for k, v in nanovllm_outputs.items()}
# Complete remaining steps (decode)
while not llm.is_finished():
llm.step()
# Use prefill outputs for comparison
nanovllm_outputs = nanovllm_outputs_prefill
else:
nanovllm_result = llm.generate([input_ids[0].tolist()], SamplingParams(temperature=0.01, max_tokens=1), use_tqdm=False)
print("Running torch inference...")
with torch.no_grad():
torch_logits, _, torch_qkv_outputs = torch_model(input_ids, output_qkv_layers=list(range(num_layers)))
# ============================================================
# Compare using cosine similarity (1.0 = perfect alignment)
# ============================================================
print("\n" + "=" * 70)
print(f"{'Layer':<8} {'I':>10} {'Q':>10} {'K':>10} {'V':>10} {'O':>10}")
print("=" * 70)
all_passed = True
threshold = 0.999 # Cosine similarity threshold
for layer_idx in range(num_layers):
if ENABLE_OFFLOAD:
# ============================================================
# Offload mode: use accumulated Q/I and CPU-side K/V
# Only compare prompt tokens (INPUT_LEN), exclude generated tokens
# ============================================================
# I: concatenate accumulated chunks, trim to prompt length
i_chunks = nanovllm_i_accum[layer_idx]
nano_in = torch.cat(i_chunks, dim=0)[:INPUT_LEN]
if nano_in.dim() == 2:
nano_in = nano_in.unsqueeze(0)
torch_in = torch_proj_inputs[layer_idx]
if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
torch_in = torch_in.view(nano_in.shape)
i_sim = cosine_sim(nano_in, torch_in)
# Q: concatenate accumulated chunks, trim to prompt length
q_chunks = nanovllm_q_accum[layer_idx]
nano_q = torch.cat(q_chunks, dim=0)[:INPUT_LEN]
torch_q = torch_qkv_outputs[layer_idx]["q"].squeeze(0).transpose(0, 1)
q_sim = cosine_sim(nano_q, torch_q)
# K, V: from CPU cache, trim to prompt length and move to GPU
nano_k = nanovllm_kv_cpu[layer_idx]["k"][:INPUT_LEN].cuda()
torch_k = torch_qkv_outputs[layer_idx]["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
k_sim = cosine_sim(nano_k, torch_k)
nano_v = nanovllm_kv_cpu[layer_idx]["v"][:INPUT_LEN].cuda()
torch_v = torch_qkv_outputs[layer_idx]["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1)
v_sim = cosine_sim(nano_v, torch_v)
# O: compare attention outputs directly
# For single-chunk case (input_len <= block_size), shapes should match
# For multi-chunk case, nano_out is the last chunk only
nano_out = nanovllm_outputs[layer_idx]
torch_out = torch_outputs[layer_idx]
if nano_out.numel() == torch_out.numel():
# Single chunk or shapes match - compare directly
o_sim = cosine_sim(nano_out, torch_out)
else:
# Multi-chunk case: compare last chunk with corresponding torch slice
last_chunk_len = nano_out.shape[1] if nano_out.dim() == 3 else nano_out.shape[0]
torch_out_slice = torch_out[:, -last_chunk_len:, :] if torch_out.dim() == 3 else torch_out[-last_chunk_len:, :]
o_sim = cosine_sim(nano_out, torch_out_slice)
else:
# ============================================================
# Non-offload mode: original logic
# ============================================================
# Input similarity
nano_in = nanovllm_proj_inputs[layer_idx]
torch_in = torch_proj_inputs[layer_idx]
if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel():
torch_in = torch_in.view(nano_in.shape)
i_sim = cosine_sim(nano_in, torch_in)
# QKV similarities
q_sim, k_sim, v_sim = compute_qkv_sims(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups)
# O similarity
nano_out = nanovllm_outputs[layer_idx]
torch_out = torch_outputs[layer_idx]
if nano_out.shape != torch_out.shape and nano_out.numel() == torch_out.numel():
torch_out = torch_out.view(nano_out.shape)
o_sim = cosine_sim(nano_out, torch_out)
# Check pass/fail
passed = all(s >= threshold for s in [i_sim, q_sim, k_sim, v_sim, o_sim])
all_passed = all_passed and passed
status = "" if passed else " *"
print(f"Layer {layer_idx:2d}{status:<3} {i_sim:>10.6f} {q_sim:>10.6f} {k_sim:>10.6f} {v_sim:>10.6f} {o_sim:>10.6f}")
# ============================================================
# Cleanup and result
# ============================================================
for hook in nanovllm_hooks + torch_hooks:
hook.remove()
print("=" * 70)
mode_str = " [offload]" if ENABLE_OFFLOAD else ""
if all_passed:
print(f"test_align{mode_str}: PASSED (cosine_sim >= 0.999)")
else:
print(f"test_align{mode_str}: FAILED (* = cosine_sim < 0.999)")

View File

@@ -1,297 +0,0 @@
"""
Test Attention layer with KV cache offload - N-way Pipeline.
This test demonstrates and verifies the N-way pipeline with:
- Per-slot transfer streams for parallel H2D
- Dedicated compute stream (avoids CUDA default stream implicit sync)
- Pre-load phase + main loop with immediate slot reuse
Key difference from previous test:
- We first pre-fill many chunks to CPU cache
- Then simulate processing a new chunk that loads ALL previous blocks
- This exercises the full N-way pipeline with many blocks in flight
"""
import torch
from nanovllm.layers.attention import Attention
from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
from nanovllm.engine.sequence import Sequence
from nanovllm.utils.context import set_context, reset_context
# ============================================================
# Configuration
# ============================================================
NUM_LAYERS = 8
NUM_HEADS = 8
NUM_KV_HEADS = 8
HEAD_DIM = 64
BLOCK_SIZE = 1024
CHUNK_SIZE = 1024
NUM_GPU_SLOTS = 6 # N-way pipeline with 6 slots
NUM_CPU_BLOCKS = 16 # Many blocks to load from CPU
DTYPE = torch.bfloat16
DEVICE = "cuda"
# ============================================================
# Setup
# ============================================================
def create_manager():
manager = HybridKVCacheManager(
num_gpu_slots=NUM_GPU_SLOTS,
num_cpu_blocks=NUM_CPU_BLOCKS,
block_size=BLOCK_SIZE,
)
manager.allocate_cache(
num_layers=NUM_LAYERS,
num_kv_heads=NUM_KV_HEADS,
head_dim=HEAD_DIM,
dtype=DTYPE,
)
return manager
def create_attention_layers(manager):
layers = []
for layer_id in range(NUM_LAYERS):
attn = Attention(
num_heads=NUM_HEADS,
head_dim=HEAD_DIM,
scale=HEAD_DIM ** -0.5,
num_kv_heads=NUM_KV_HEADS,
)
attn.layer_id = layer_id
k_cache, v_cache = manager.get_layer_cache(layer_id)
attn.k_cache = k_cache
attn.v_cache = v_cache
layers.append(attn.to(DEVICE))
return layers
# ============================================================
# Pre-fill CPU cache with random data
# ============================================================
def prefill_cpu_cache(manager, num_blocks):
"""
Fill CPU cache with random KV data for num_blocks blocks.
This simulates having already processed many chunks.
"""
offload_engine = manager.offload_engine
for block_id in range(num_blocks):
# Generate random KV data for all layers
for layer_id in range(NUM_LAYERS):
k_data = torch.randn(
BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM,
dtype=DTYPE, device=DEVICE
)
v_data = torch.randn(
BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM,
dtype=DTYPE, device=DEVICE
)
# Copy to CPU cache
offload_engine.k_cache_cpu[layer_id, block_id].copy_(k_data)
offload_engine.v_cache_cpu[layer_id, block_id].copy_(v_data)
return list(range(num_blocks))
# ============================================================
# Simulate N-way Pipeline (mirrors attention.py logic)
# ============================================================
def simulate_nway_pipeline(
layer_id: int,
q_batched: torch.Tensor,
cpu_block_table: list,
load_slots: list,
offload_engine,
scale: float,
):
"""
Simulate N-way pipeline for a single layer.
This mirrors the logic in Attention._ring_buffer_pipeline_load().
"""
num_blocks = len(cpu_block_table)
num_slots = len(load_slots)
o_acc, lse_acc = None, None
# Phase 1: Pre-load up to num_slots blocks
num_preload = min(num_slots, num_blocks)
torch.cuda.nvtx.range_push(f"Phase1_Preload: L{layer_id}")
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])
torch.cuda.nvtx.range_pop()
# Phase 2: Main loop with compute_stream
compute_stream = offload_engine.compute_stream
for block_idx in range(num_blocks):
torch.cuda.nvtx.range_push(f"Block: L{layer_id} B{block_idx}")
current_slot = load_slots[block_idx % num_slots]
# Wait for transfer
offload_engine.wait_slot_layer(current_slot, layer_id)
# Compute on dedicated stream
with torch.cuda.stream(compute_stream):
torch.cuda.nvtx.range_push(f"FlashAttn: L{layer_id} B{block_idx}")
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=scale,
causal=False,
)
torch.cuda.nvtx.range_pop()
offload_engine.record_slot_compute_done(current_slot, layer_id)
# Start next transfer (reuse current_slot)
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
offload_engine.load_to_slot_layer(
current_slot, layer_id, cpu_block_table[next_block_idx]
)
# Merge
with torch.cuda.stream(compute_stream):
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
torch.cuda.nvtx.range_pop()
return o_acc, lse_acc
def simulate_full_forward(layers, manager, cpu_block_table, chunk_size):
"""
Simulate forward pass through all layers, loading previous blocks from CPU.
This is the key test: many blocks loaded via N-way pipeline.
"""
offload_engine = manager.offload_engine
# Current chunk index (we're processing the "next" chunk after all prefilled ones)
current_chunk_idx = len(cpu_block_table)
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
# Random query for attention
q = torch.randn(1, chunk_size, NUM_HEADS, HEAD_DIM, dtype=DTYPE, device=DEVICE)
outputs = []
for layer in layers:
torch.cuda.nvtx.range_push(f"Layer: {layer.layer_id}")
o_acc, lse_acc = simulate_nway_pipeline(
layer.layer_id,
q,
cpu_block_table,
load_slots,
offload_engine,
layer.scale,
)
outputs.append(o_acc)
torch.cuda.nvtx.range_pop()
return outputs
# ============================================================
# Main Test
# ============================================================
print("=" * 60)
print("Test: N-way Pipeline with CPU Offload")
print("=" * 60)
# 1. Setup
print("\n[1] Creating manager and attention layers...")
manager = create_manager()
layers = create_attention_layers(manager)
offload_engine = manager.offload_engine
print(f" - GPU slots: {NUM_GPU_SLOTS}")
print(f" - CPU blocks: {NUM_CPU_BLOCKS}")
print(f" - Per-slot streams: {len(offload_engine.slot_transfer_streams)}")
print(f" - Compute stream: {offload_engine.compute_stream}")
# 2. Pre-fill CPU cache
NUM_PREV_BLOCKS = 12 # Many blocks to load via N-way pipeline
print(f"\n[2] Pre-filling {NUM_PREV_BLOCKS} blocks to CPU cache...")
cpu_block_table = prefill_cpu_cache(manager, NUM_PREV_BLOCKS)
print(f" - CPU blocks filled: {cpu_block_table}")
# 3. Verify pipeline configuration
current_chunk_idx = NUM_PREV_BLOCKS
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
print(f"\n[3] Pipeline configuration for chunk {current_chunk_idx}:")
print(f" - Write slot: {write_slot}")
print(f" - Load slots: {load_slots}")
print(f" - Pipeline depth (N-way): {len(load_slots)}")
assert len(load_slots) == NUM_GPU_SLOTS - 1, f"Expected {NUM_GPU_SLOTS - 1} load slots"
# 4. Warmup
print("\n[4] Warmup (3 iterations)...")
for i in range(3):
outputs = simulate_full_forward(layers, manager, cpu_block_table, CHUNK_SIZE)
torch.cuda.synchronize()
print(f" - Warmup {i+1}/3 done")
# 5. Benchmark
NUM_ITERS = 10
print(f"\n[5] Benchmark ({NUM_ITERS} iterations)...")
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for i in range(NUM_ITERS):
torch.cuda.nvtx.range_push(f"Iteration_{i}")
outputs = simulate_full_forward(layers, manager, cpu_block_table, CHUNK_SIZE)
torch.cuda.nvtx.range_pop()
end_event.record()
torch.cuda.synchronize()
elapsed_ms = start_event.elapsed_time(end_event)
# Stats
total_blocks_loaded = NUM_PREV_BLOCKS * NUM_LAYERS * NUM_ITERS
blocks_per_sec = total_blocks_loaded / (elapsed_ms / 1000)
total_tokens = NUM_PREV_BLOCKS * BLOCK_SIZE * NUM_LAYERS * NUM_ITERS
tokens_per_sec = total_tokens / (elapsed_ms / 1000)
print(f"\n[6] Results:")
print(f" - Total time: {elapsed_ms:.2f} ms")
print(f" - Per iteration: {elapsed_ms / NUM_ITERS:.2f} ms")
print(f" - Blocks loaded: {total_blocks_loaded} ({blocks_per_sec:.0f} blocks/s)")
print(f" - Tokens processed: {total_tokens} ({tokens_per_sec:.0f} tok/s)")
# 7. Verification
print("\n[7] Verification:")
assert len(outputs) == NUM_LAYERS, f"Expected {NUM_LAYERS} outputs"
for i, o in enumerate(outputs):
assert o is not None, f"Layer {i} output is None"
assert o.shape == (1, CHUNK_SIZE, NUM_HEADS, HEAD_DIM), f"Layer {i} shape mismatch"
print(" - All layer outputs valid ✓")
print(" - N-way pipeline executed correctly ✓")
# Cleanup
reset_context()
print("\n" + "=" * 60)
print("test_attention_offload: PASSED")
print("=" * 60)

View File

@@ -1,169 +0,0 @@
"""
Test script for chunked attention correctness.
Validates that chunked prefill using flash_attn_with_lse + merge_attention_outputs
produces the same result as full flash_attn_varlen_func.
Scenario: Simulating chunked prefill where we process query chunk by chunk.
For each query chunk i:
- KV contains all tokens from chunk 0 to chunk i
- Previous KV chunks (0 to i-1): full attention (no causal mask)
- Current KV chunk (i): causal attention (diagonal block)
"""
import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_func
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# ============================================================
# Utility Functions
# ============================================================
def compute_chunked_prefill_for_chunk(
q_chunk: torch.Tensor,
kv_chunks: list,
current_chunk_idx: int,
) -> torch.Tensor:
"""
Compute attention for a single query chunk against all KV chunks up to current.
This simulates chunked prefill for query chunk `current_chunk_idx`:
- KV chunks 0 to current_chunk_idx-1: full attention (all previous tokens visible)
- KV chunk current_chunk_idx: causal attention (diagonal block)
Args:
q_chunk: [batch, chunk_size, nheads, headdim] - current query chunk
kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim]
current_chunk_idx: Index of the current chunk being processed
Returns:
out: [batch, chunk_size, nheads, headdim]
"""
accumulated_o = None
accumulated_lse = None
for i in range(current_chunk_idx + 1):
k_chunk, v_chunk = kv_chunks[i]
# Previous chunks: no causal mask (all tokens visible)
# Current chunk (diagonal): causal mask
is_diagonal = (i == current_chunk_idx)
chunk_o, chunk_lse = flash_attn_with_lse(
q_chunk, k_chunk, v_chunk, causal=is_diagonal
)
if accumulated_o is None:
accumulated_o = chunk_o
accumulated_lse = chunk_lse
else:
accumulated_o, accumulated_lse = merge_attention_outputs(
accumulated_o, accumulated_lse,
chunk_o, chunk_lse
)
return accumulated_o
def compute_reference_causal(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> torch.Tensor:
"""
Compute reference causal attention using flash_attn_func.
Args:
q, k, v: [batch, seqlen, nheads, headdim]
Returns:
out: [batch, seqlen, nheads, headdim]
"""
return flash_attn_func(q, k, v, causal=True)
# ============================================================
# Main Test Script
# ============================================================
torch.manual_seed(42)
# Test configurations: (batch, num_chunks, chunk_size, nheads, headdim)
TEST_CASES = [
(1, 4, 256, 8, 128),
(1, 4, 512, 8, 128),
(1, 8, 512, 8, 128),
(1, 32, 1024, 8, 128),
(1, 32, 1024, 32, 128), # More heads
(1, 32, 256, 8, 64), # Smaller head dim
]
DTYPES = [torch.float16, torch.bfloat16]
print("=" * 80)
print("Test: Chunked Prefill Attention vs Reference (flash_attn_func causal)")
print("=" * 80)
print("Simulating chunked prefill: Q chunk attends to all KV chunks up to current")
print(" - Previous KV chunks: full attention (no causal mask)")
print(" - Current KV chunk (diagonal): causal attention")
print()
all_passed = True
for dtype in DTYPES:
print(f"--- dtype: {dtype} ---")
for batch, num_chunks, chunk_size, nheads, headdim in TEST_CASES:
seqlen = num_chunks * chunk_size
# Generate full Q, K, V
q_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
k_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
v_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
# Reference: full causal attention
out_ref = compute_reference_causal(q_full, k_full, v_full)
# Split into chunks
q_chunks = [q_full[:, i*chunk_size:(i+1)*chunk_size] for i in range(num_chunks)]
kv_chunks = [
(k_full[:, i*chunk_size:(i+1)*chunk_size],
v_full[:, i*chunk_size:(i+1)*chunk_size])
for i in range(num_chunks)
]
# Compute chunked prefill for each query chunk
out_chunks = []
for chunk_idx in range(num_chunks):
chunk_out = compute_chunked_prefill_for_chunk(
q_chunks[chunk_idx],
kv_chunks,
chunk_idx,
)
out_chunks.append(chunk_out)
# Concatenate chunked outputs
out_chunked = torch.cat(out_chunks, dim=1)
# Compare
diff = (out_ref - out_chunked).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
# Tolerance: fp16/bf16 have limited precision
tol = 1e-2
passed = max_diff < tol
all_passed = all_passed and passed
status = "PASS" if passed else "FAIL"
print(
f"[{status}] seqlen={seqlen:5d} chunks={num_chunks} "
f"chunk_size={chunk_size:4d} heads={nheads:2d} dim={headdim:3d} "
f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}"
)
print()
print("=" * 80)
assert all_passed, "Some tests failed!"
print("test_chunked_attention: PASSED")

View File

@@ -1,391 +0,0 @@
"""
Correctness test for chunked decode attention.
Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import torch
from random import randint, seed
from typing import Dict, List
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_func
# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
NUM_DECODE_TOKENS = 5
BLOCK_SIZE = 1024
# State
prefill_captures: List[Dict] = []
decode_captures: List[Dict] = []
def make_ones_injection_hook():
"""Inject Q=K=V=1.0 for deterministic testing."""
def hook(module, inputs):
q, k, v = inputs[0], inputs[1], inputs[2]
q_ones = torch.ones_like(q)
k_ones = torch.ones_like(k)
v_ones = torch.ones_like(v)
return (q_ones, k_ones, v_ones) + inputs[3:]
return hook
def make_capture_hook(layer_id: int):
"""Capture Q, K, V, output during inference."""
def hook(module, inputs, output):
ctx = get_context()
q, k, v = inputs
if ctx.is_prefill:
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
prefill_captures.append({
'layer_id': layer_id,
'chunk_idx': chunk_idx,
'q': q.clone().cpu(),
'k': k.clone().cpu(),
'v': v.clone().cpu(),
'output': output.clone().cpu(),
})
else:
decode_step = len([c for c in decode_captures if c['layer_id'] == layer_id])
decode_captures.append({
'layer_id': layer_id,
'decode_step': decode_step,
'q': q.clone().cpu(),
'k': k.clone().cpu(),
'v': v.clone().cpu(),
'output': output.clone().cpu(),
})
return hook
def compute_decode_reference(layer_id: int, decode_step: int, scale: float,
k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
block_size: int, num_prefill_chunks: int) -> torch.Tensor:
"""
Compute reference decode output using CPU KV cache and standard flash attention.
For decode, query attends to:
1. All prefill KV (from CPU cache)
2. All previous decode tokens (from captured decode k, v)
"""
# Get decode capture for this layer and step
decode_cap = None
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
decode_cap = c
break
if decode_cap is None:
return None
# Query: single decode token
q = decode_cap['q'].cuda() # [1, num_heads, head_dim]
q_batched = q.unsqueeze(0) # [1, 1, num_heads, head_dim]
# Collect all K, V: prefill chunks from captures + decode tokens from captures
# NOTE: We use prefill captures directly instead of CPU cache because
# the CPU block ID may not equal the chunk index.
all_k = []
all_v = []
# 1. Prefill chunks from captures (use captured K/V, not CPU cache)
for cidx in range(num_prefill_chunks):
prefill_cap = None
for c in prefill_captures:
if c['layer_id'] == layer_id and c['chunk_idx'] == cidx:
prefill_cap = c
break
if prefill_cap is not None:
# Use captured K/V directly (guaranteed to be correct layer data)
all_k.append(prefill_cap['k'].cuda())
all_v.append(prefill_cap['v'].cuda())
# 2. Decode tokens from captures (up to and including current step)
for step in range(decode_step + 1):
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == step:
all_k.append(c['k'].cuda())
all_v.append(c['v'].cuda())
break
if not all_k:
return None
# Concatenate all K, V
full_k = torch.cat(all_k, dim=0).unsqueeze(0) # [1, total_len, kv_heads, head_dim]
full_v = torch.cat(all_v, dim=0).unsqueeze(0)
# Run flash attention (non-causal since we explicitly control what KV to include)
output = flash_attn_func(
q_batched, full_k, full_v,
softmax_scale=scale,
causal=False,
)
return output.squeeze(0).squeeze(0).cpu() # [num_heads, head_dim]
# ============================================================
# Main
# ============================================================
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
dtype="float16",
)
# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5
# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
# Pre-hook: inject all ones for Q, K, V
# pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
# hooks.append(pre_hook)
# Post-hook: capture Q, K, V, output
post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
hooks.append(post_hook)
# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS), use_tqdm=False)
# Remove hooks
for hook in hooks:
hook.remove()
# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()
# Calculate number of prefill chunks
num_prefill_chunks = INPUT_LEN // BLOCK_SIZE
# Debug: Compare decode_buffer with captured K/V
print("\n=== DEBUG: Comparing decode_buffer with captured K/V ===")
decode_k_buffer = offload_engine.decode_k_buffer.clone().cpu()
for step in range(NUM_DECODE_TOKENS):
for layer_id in [0, 17, 35]: # Sample a few layers
# Find captured K for this step and layer
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == step:
captured_k = c['k'].squeeze(0) # [kv_heads, head_dim]
buffer_k = decode_k_buffer[layer_id, step] # [kv_heads, head_dim]
diff = (captured_k - buffer_k).abs().max().item()
print(f"Step {step}, Layer {layer_id}: captured vs buffer max_diff={diff:.6f}")
break
# Debug: Verify that decode_buffer slices match concatenated captures
print("\n=== DEBUG: Verifying decode_buffer slices ===")
for layer_id in [0]:
for decode_step in [1, 2]: # Check steps that use multiple tokens
# Build expected slice from captures
expected_k_list = []
for step in range(decode_step + 1):
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == step:
expected_k_list.append(c['k'].squeeze(0)) # [kv_heads, head_dim]
break
if expected_k_list:
expected_k = torch.stack(expected_k_list, dim=0) # [num_tokens, kv_heads, head_dim]
buffer_slice = decode_k_buffer[layer_id, 0:decode_step+1]
diff = (expected_k - buffer_slice).abs().max().item()
print(f"Decode step {decode_step}, Layer {layer_id}: buffer slice vs expected max_diff={diff:.6f}")
# Print first values
print(f" Buffer[0,0,0]={buffer_slice[0,0,0].item():.6f}, Expected[0,0,0]={expected_k[0,0,0].item():.6f}")
if decode_step >= 1:
print(f" Buffer[1,0,0]={buffer_slice[1,0,0].item():.6f}, Expected[1,0,0]={expected_k[1,0,0].item():.6f}")
# Debug: Print expected K value for block 0, layer 0 (to compare with actual loading)
print("\n=== DEBUG: Expected K values for block 0, layer 0 ===")
for c in prefill_captures:
if c['layer_id'] == 0 and c['chunk_idx'] == 0:
print(f"Captured K[0,0,0] for layer 0, chunk 0: {c['k'][0,0,0].item():.6f}")
break
print(f"CPU cache K[0,0,0,0,0] for layer 0, block 0: {k_cache_cpu[0,0,0,0,0].item():.6f}")
# Debug: Compare CPU cache with captured prefill K/V
print("\n=== DEBUG: Comparing CPU cache with captured prefill K/V ===")
for chunk_idx in [0, 7, 15]: # Sample a few chunks
for layer_id in [0, 17, 35]: # Sample a few layers
# Find captured K for this chunk and layer
for c in prefill_captures:
if c['layer_id'] == layer_id and c['chunk_idx'] == chunk_idx:
captured_k = c['k'] # [seq_len, kv_heads, head_dim]
cpu_cache_k = k_cache_cpu[layer_id, chunk_idx, :captured_k.shape[0]]
diff = (captured_k - cpu_cache_k).abs().max().item()
print(f"Chunk {chunk_idx}, Layer {layer_id}: captured vs CPU cache max_diff={diff:.6f}")
break
# Debug: Get cpu_block_table to check order
kvcache_manager = llm.model_runner.kvcache_manager
# Find the sequence (it should still exist)
from nanovllm.engine.sequence import Sequence
for attr_name in ['sequences', '_sequences', 'active_sequences']:
if hasattr(kvcache_manager, attr_name):
print(f"Found {attr_name}")
break
# Try to get cpu_block_table through a different way
print(f"\n=== DEBUG: CPU block order ===")
# For each prefill capture, check which CPU block it ended up in
for chunk_idx in range(num_prefill_chunks):
for c in prefill_captures:
if c['layer_id'] == 0 and c['chunk_idx'] == chunk_idx:
# Check if this chunk's K matches any CPU block
captured_k_first = c['k'][0, 0, 0].item()
for block_id in range(num_prefill_chunks):
cpu_k_first = k_cache_cpu[0, block_id, 0, 0, 0].item()
if abs(captured_k_first - cpu_k_first) < 1e-6:
print(f"Chunk {chunk_idx} -> CPU block {block_id}")
break
break
# Debug: Check reference vs actual for decode steps 0 and 1
# Also compute partial references (prefill only, decode only) to isolate the bug
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
for decode_step in [0, 1]:
print(f"\n=== DEBUG: Reference vs Actual for layer 0, decode {decode_step} ===")
layer_id = 0
# Find the capture
for c in decode_captures:
if c['layer_id'] == layer_id and c['decode_step'] == decode_step:
q = c['q'].cuda() # [1, num_heads, head_dim]
q_batched = q.unsqueeze(0) # [1, 1, num_heads, head_dim]
# Build prefill K/V per-block for block-by-block reference
prefill_k_blocks = []
prefill_v_blocks = []
for cidx in range(num_prefill_chunks):
for pc in prefill_captures:
if pc['layer_id'] == layer_id and pc['chunk_idx'] == cidx:
prefill_k_blocks.append(pc['k'].cuda().unsqueeze(0)) # [1, block_size, kv_heads, head_dim]
prefill_v_blocks.append(pc['v'].cuda().unsqueeze(0))
break
# Build decode K/V
decode_k_list = []
decode_v_list = []
for step in range(decode_step + 1):
for dc in decode_captures:
if dc['layer_id'] == layer_id and dc['decode_step'] == step:
decode_k_list.append(dc['k'].cuda())
decode_v_list.append(dc['v'].cuda())
break
full_prefill_k = torch.cat([kb.squeeze(0) for kb in prefill_k_blocks], dim=0).unsqueeze(0)
full_prefill_v = torch.cat([vb.squeeze(0) for vb in prefill_v_blocks], dim=0).unsqueeze(0)
full_decode_k = torch.cat(decode_k_list, dim=0).unsqueeze(0)
full_decode_v = torch.cat(decode_v_list, dim=0).unsqueeze(0)
full_k = torch.cat([full_prefill_k, full_decode_k], dim=1)
full_v = torch.cat([full_prefill_v, full_decode_v], dim=1)
print(f"Q shape: {q_batched.shape}")
print(f"Prefill K shape: {full_prefill_k.shape}")
print(f"Decode K shape: {full_decode_k.shape}")
print(f"Full K shape: {full_k.shape}")
print(f"Total tokens: prefill={num_prefill_chunks * BLOCK_SIZE}, decode={decode_step + 1}")
# Reference output (single attention over all)
ref_output = flash_attn_func(
q_batched, full_k, full_v,
softmax_scale=scale,
causal=False,
)
# Chunked reference: prefill attention + decode attention + merge
prefill_o, prefill_lse = flash_attn_with_lse(
q_batched, full_prefill_k, full_prefill_v,
softmax_scale=scale,
causal=False,
)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, full_decode_k, full_decode_v,
softmax_scale=scale,
causal=False,
)
chunked_output, _ = merge_attention_outputs(prefill_o, prefill_lse, decode_o, decode_lse)
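# merge_attention_outputs is assumed to combine the two partial softmax results via
# their log-sum-exp statistics, roughly:
#   lse = logaddexp(prefill_lse, decode_lse)
#   out = prefill_o * exp(prefill_lse - lse) + decode_o * exp(decode_lse - lse)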
# Block-by-block reference (simulating ring buffer pipeline)
block_o_acc, block_lse_acc = None, None
for bidx, (kb, vb) in enumerate(zip(prefill_k_blocks, prefill_v_blocks)):
o_blk, lse_blk = flash_attn_with_lse(q_batched, kb, vb, softmax_scale=scale, causal=False)
if block_o_acc is None:
block_o_acc, block_lse_acc = o_blk, lse_blk
else:
block_o_acc, block_lse_acc = merge_attention_outputs(block_o_acc, block_lse_acc, o_blk, lse_blk)
# Compare block-by-block vs single
block_vs_single_diff = (block_o_acc - prefill_o).abs().max().item()
print(f"Block-by-block vs Single max_diff: {block_vs_single_diff:.6f}")
# Compare full reference vs chunked reference
ref_vs_chunked_diff = (ref_output - chunked_output).abs().max().item()
print(f"Reference vs Chunked-reference max_diff: {ref_vs_chunked_diff:.6f}")
ref_output = ref_output.squeeze(0).squeeze(0).cpu()
chunked_output_cpu = chunked_output.squeeze(0).squeeze(0).cpu()
# Actual output
actual_output = c['output'].squeeze(0)
if actual_output.dim() == 3:
actual_output = actual_output.squeeze(0)
diff_ref = (actual_output - ref_output).abs()
diff_chunked = (actual_output - chunked_output_cpu).abs()
print(f"Actual vs Reference max_diff: {diff_ref.max().item():.6f}")
print(f"Actual vs Chunked-ref max_diff: {diff_chunked.max().item():.6f}")
break
print()
# Verify decode outputs
all_passed = True
for c in decode_captures:
layer_id = c['layer_id']
decode_step = c['decode_step']
ref_output = compute_decode_reference(
layer_id, decode_step, scale,
k_cache_cpu, v_cache_cpu, BLOCK_SIZE, num_prefill_chunks
)
if ref_output is None:
continue
actual_output = c['output'].squeeze(0)
if actual_output.dim() == 3:
actual_output = actual_output.squeeze(0)
diff = (actual_output - ref_output).abs()
max_diff = diff.max().item()
passed = max_diff < 1e-1
all_passed = all_passed and passed
if not passed:
print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}")
print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}")

View File

@@ -1,196 +0,0 @@
"""
Correctness test for chunked prefill attention.
Captures Q and output during inference, then computes reference using
CPU KV cache with standard flash attention.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import torch
from random import randint, seed
from typing import Dict, List
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
MAX_MODEL_LEN = 128 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
BLOCK_SIZE = 1024
# State - capture Q and output for each (layer, chunk)
captures: List[Dict] = []
def make_ones_injection_hook():
"""Inject Q=K=V=1.0 for deterministic testing."""
def hook(module, inputs):
ctx = get_context()
if not ctx.is_prefill:
return inputs
q, k, v = inputs[0], inputs[1], inputs[2]
q_ones = torch.ones_like(q)
k_ones = torch.ones_like(k)
v_ones = torch.ones_like(v)
return (q_ones, k_ones, v_ones) + inputs[3:]
return hook
def make_capture_hook(layer_id: int):
"""Capture Q and output during prefill."""
def hook(module, inputs, output):
ctx = get_context()
if not ctx.is_prefill:
return
q, k, v = inputs[0], inputs[1], inputs[2]
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
captures.append({
'layer_id': layer_id,
'chunk_idx': chunk_idx,
'q': q.clone().cpu(),
'k': k.clone().cpu(),
'v': v.clone().cpu(),
'output': output.clone().cpu(),
})
return hook
def compute_reference(layer_id: int, chunk_idx: int, scale: float,
k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor,
block_size: int) -> torch.Tensor:
"""
Compute reference output using CPU KV cache and standard flash attention.
Concatenates all Q, K, V from chunks 0..chunk_idx and runs causal attention,
then extracts output for the current chunk.
"""
# Get all captures for this layer up to chunk_idx
layer_captures = [c for c in captures
if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx]
layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx'])
if not layer_captures:
return None
# Collect Q from captures, K/V from CPU cache
all_q = []
all_k = []
all_v = []
chunk_lengths = []
for c in layer_captures:
cidx = c['chunk_idx']
q = c['q'].cuda() # [seqlen, nheads, headdim]
all_q.append(q)
chunk_lengths.append(q.shape[0])
# Get K, V from CPU cache (already offloaded during prefill)
# CPU cache shape: [num_layers, num_blocks, block_size, kv_heads, head_dim]
k = k_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
v = v_cache_cpu[layer_id, cidx, :q.shape[0]].cuda()
all_k.append(k)
all_v.append(v)
# Concatenate
full_q = torch.cat(all_q, dim=0)
full_k = torch.cat(all_k, dim=0)
full_v = torch.cat(all_v, dim=0)
total_len = full_q.shape[0]
# Run standard causal flash attention
cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda')
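# A single [0, total_len] entry means the varlen kernel sees one packed sequence,
# so this call is equivalent to ordinary batch-of-1 causal attention.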
full_o = flash_attn_varlen_func(
full_q, full_k, full_v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_len,
max_seqlen_k=total_len,
softmax_scale=scale,
causal=True,
)
# Extract output for current chunk
start_pos = sum(chunk_lengths[:-1])
end_pos = sum(chunk_lengths)
return full_o[start_pos:end_pos].cpu()
# ============================================================
# Main
# ============================================================
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
dtype="float16",
)
# Get model info
num_layers = len(llm.model_runner.model.model.layers)
head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim
scale = head_dim ** -0.5
# Register hooks
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
# Pre-hook: inject all ones for Q, K, V
# pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook())
# hooks.append(pre_hook)
# Post-hook: capture Q, K, V, output
post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx))
hooks.append(post_hook)
# Run inference
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=1), use_tqdm=False)
# Remove hooks
for hook in hooks:
hook.remove()
# Get CPU cache reference
offload_engine = llm.model_runner.kvcache_manager.offload_engine
k_cache_cpu = offload_engine.k_cache_cpu.clone()
v_cache_cpu = offload_engine.v_cache_cpu.clone()
# Verify: compare actual output with reference computed from CPU cache
all_passed = True
num_chunks = INPUT_LEN // BLOCK_SIZE
for idx,c in enumerate(captures):
layer_id = c['layer_id']
chunk_idx = c['chunk_idx']
# Skip chunk 0 (no previous KV to load)
if chunk_idx == 0:
continue
ref_output = compute_reference(layer_id, chunk_idx, scale, k_cache_cpu, v_cache_cpu, BLOCK_SIZE)
if ref_output is None:
continue
actual_output = c['output']
diff = (actual_output - ref_output).abs()
max_diff = diff.max().item()
passed = max_diff < 1e-1 # float16 tolerance
all_passed = all_passed and passed
if not passed:
print(f"[FAIL] Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}")
__import__('pdb').set_trace()
print(f"test_chunked_prefill_hook: {'PASSED' if all_passed else 'FAILED'}")

View File

@@ -1,137 +0,0 @@
"""
Test KV cache offload correctness using debug hooks.
Injects distinctive K/V values, verifies loaded tensors match expected patterns.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import inspect
from random import randint, seed
from typing import Dict, List
import torch
from torch import Tensor
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
# Config
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 32 * 1024
NUM_GPU_BLOCKS = 4
INPUT_LEN = 32 * 1024
BLOCK_SIZE = 1024
# State
load_log: List[Dict] = []
current_chunk: List[int] = [0]
def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None:
"""Record loaded tensor values for layer 0."""
if layer_id != 0:
return
# Walk up the stack to find the offload engine (the caller's self) and print k_cache_gpu[*][0,0,0] for all slots
frame = inspect.currentframe()
try:
caller_frame = frame.f_back
if caller_frame is not None:
local_vars = caller_frame.f_locals
if 'self' in local_vars:
self_obj = local_vars['self']
if hasattr(self_obj, 'k_cache_gpu'):
num_slots = self_obj.k_cache_gpu.shape[0]
vals = []
for i in range(num_slots):
val = self_obj.k_cache_gpu[i][0,0,0].item()
if i == slot_idx:
vals.append(f"[{val}]")
else:
vals.append(str(val))
print(f"[DEBUG] k_cache_gpu[0..{num_slots-1}][0,0,0] = [{', '.join(vals)}]")
finally:
del frame
load_log.append({
"chunk_idx": current_chunk[0],
"cpu_block_id": cpu_block_id,
"k_value": k.float().mean().item(),
})
def make_pattern_injection_hook(layer_id):
"""Inject K = chunk_idx + 1, V = -(chunk_idx + 1) for layer 0."""
def hook(module, inputs):
ctx = get_context()
if not ctx.is_prefill or layer_id != 0:
return inputs
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
current_chunk[0] = chunk_idx
if len(inputs) >= 3:
q, k, v = inputs[0], inputs[1], inputs[2]
k_new = torch.full_like(k, float(chunk_idx + 1))
v_new = torch.full_like(v, float(-(chunk_idx + 1)))
return (q, k_new, v_new) + inputs[3:]
return inputs
return hook
def verify() -> bool:
"""Verify blocks loaded in correct order with correct K values."""
chunk_loads: Dict[int, List[tuple]] = {}
for log in load_log:
chunk = log["chunk_idx"]
if chunk not in chunk_loads:
chunk_loads[chunk] = []
chunk_loads[chunk].append((log["cpu_block_id"], log["k_value"]))
for chunk, loads in chunk_loads.items():
expected_blocks = list(range(chunk))
actual_blocks = [b for b, _ in loads]
k_values = [k for _, k in loads]
expected_k = [float(b + 1) for b in expected_blocks]
if actual_blocks != expected_blocks:
return False
if not all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k)):
return False
return True
# Main
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
dtype="float16",
)
offload_engine = llm.model_runner.kvcache_manager.offload_engine
offload_engine.enable_debug_mode()
offload_engine.register_debug_hook(debug_load_hook)
hooks = []
for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers):
hooks.append(decoder_layer.self_attn.attn.register_forward_pre_hook(
make_pattern_injection_hook(layer_idx)
))
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1), use_tqdm=False)
for hook in hooks:
hook.remove()
offload_engine.remove_debug_hook(debug_load_hook)
offload_engine.disable_debug_mode()
# Verify
num_chunks = INPUT_LEN // BLOCK_SIZE
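# Chunk i re-loads all i previous CPU blocks for layer 0, so the total number of
# recorded loads should be 0 + 1 + ... + (num_chunks - 1) = num_chunks * (num_chunks - 1) / 2.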
expected_loads = num_chunks * (num_chunks - 1) // 2
passed = len(load_log) == expected_loads and verify()
print(f"test_debug_verification: {'PASSED' if passed else 'FAILED'}")

View File

@@ -1,276 +0,0 @@
"""
Test script for flash_attn_with_kvcache based chunked prefill.
Verifies that chunked prefill produces identical results to full attention.
"""
import torch
from flash_attn import flash_attn_func, flash_attn_with_kvcache
def chunk_prefill(q_full, k_full, v_full, k_cache, v_cache, cache_seqlens, chunk_size):
"""
Chunked prefill using flash_attn_with_kvcache.
Args:
q_full, k_full, v_full: [batch, total_seq_len, heads, head_dim]
k_cache, v_cache: [batch, max_seq_len, kv_heads, head_dim]
cache_seqlens: [batch] - current cache lengths
chunk_size: size of each chunk
Returns:
output: [batch, total_seq_len, heads, head_dim]
"""
total_len = q_full.shape[1]
outputs = []
for start in range(0, total_len, chunk_size):
end = min(start + chunk_size, total_len)
q_chunk = q_full[:, start:end]
k_chunk = k_full[:, start:end]
v_chunk = v_full[:, start:end]
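# flash_attn_with_kvcache with k=/v= supplied writes the chunk into k_cache/v_cache
# at offset cache_seqlens and attends over (cached prefix + new chunk), with the
# causal mask aligned to the chunk's absolute positions.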
out = flash_attn_with_kvcache(
q_chunk,
k_cache,
v_cache,
k=k_chunk,
v=v_chunk,
cache_seqlens=cache_seqlens,
causal=True,
)
outputs.append(out)
cache_seqlens += (end - start)
return torch.cat(outputs, dim=1)
def reference_attention(q, k, v):
"""Standard flash attention as reference."""
return flash_attn_func(q, k, v, causal=True)
def test_chunked_prefill_correctness():
"""Test that chunked prefill matches full attention."""
batch_size = 1
num_heads = 32
num_kv_heads = 8 # GQA
head_dim = 128
max_seq_len = 131072 # 128K
test_configs = [
(1024, 256), # 1K tokens, 256 chunk
(2048, 512), # 2K tokens, 512 chunk
(4096, 1024), # 4K tokens, 1K chunk
(4096, 2048), # 4K tokens, 2K chunk (2 chunks)
(8192, 2048), # 8K tokens, 2K chunk (4 chunks)
(16384, 4096), # 16K tokens, 4K chunk
(32768, 4096), # 32K tokens, 4K chunk
(65536, 8192), # 64K tokens, 8K chunk
(131072, 8192), # 128K tokens, 8K chunk (16 chunks)
]
for seq_len, chunk_size in test_configs:
print(f"\nTesting seq_len={seq_len}, chunk_size={chunk_size}...")
# Generate random input
torch.manual_seed(42)
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
# Expand K/V for non-GQA reference
k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)
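# (flash_attn_func also supports GQA directly; expanding K/V here just makes the
# reference an explicit MHA computation over the same values.)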
# Reference: full attention
ref_out = reference_attention(q, k_expanded, v_expanded)
# Chunked prefill with KV cache
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
chunked_out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)
# Compare
max_diff = (ref_out - chunked_out).abs().max().item()
mean_diff = (ref_out - chunked_out).abs().mean().item()
# Verify cache was filled correctly
assert cache_seqlens[0].item() == seq_len, f"Cache seqlen mismatch: {cache_seqlens[0].item()} != {seq_len}"
# Check K/V cache content
k_cache_diff = (k_cache[:, :seq_len] - k).abs().max().item()
v_cache_diff = (v_cache[:, :seq_len] - v).abs().max().item()
print(f" Output max_diff: {max_diff:.6f}, mean_diff: {mean_diff:.6f}")
print(f" KV cache diff: k={k_cache_diff:.6f}, v={v_cache_diff:.6f}")
# Tolerance for fp16
tolerance = 1e-2
if max_diff < tolerance:
print(f" PASSED")
else:
print(f" FAILED (max_diff {max_diff:.6f} >= {tolerance})")
return False
return True
def test_incremental_decode():
"""Test that decode after chunked prefill works correctly."""
batch_size = 1
num_heads = 32
num_kv_heads = 8
head_dim = 128
max_seq_len = 8192
prefill_len = 2048
chunk_size = 512
num_decode_steps = 10
print(f"\nTesting incremental decode after chunked prefill...")
print(f" Prefill: {prefill_len} tokens, chunk_size={chunk_size}")
print(f" Decode: {num_decode_steps} steps")
torch.manual_seed(42)
# Prefill phase
q_prefill = torch.randn(batch_size, prefill_len, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
# Run chunked prefill
prefill_out = chunk_prefill(q_prefill, k_prefill, v_prefill,
k_cache, v_cache, cache_seqlens, chunk_size)
print(f" After prefill: cache_seqlens={cache_seqlens[0].item()}")
# Decode phase - one token at a time
for step in range(num_decode_steps):
q_decode = torch.randn(batch_size, 1, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
decode_out = flash_attn_with_kvcache(
q_decode,
k_cache,
v_cache,
k=k_decode,
v=v_decode,
cache_seqlens=cache_seqlens,
causal=True,
)
cache_seqlens += 1
assert decode_out.shape == (batch_size, 1, num_heads, head_dim)
expected_len = prefill_len + num_decode_steps
actual_len = cache_seqlens[0].item()
print(f" After decode: cache_seqlens={actual_len}")
if actual_len == expected_len:
print(f" PASSED")
return True
else:
print(f" FAILED: expected {expected_len}, got {actual_len}")
return False
def test_batch_processing():
"""Test chunked prefill with batch > 1."""
batch_size = 4
num_heads = 32
num_kv_heads = 8
head_dim = 128
max_seq_len = 4096
seq_len = 2048
chunk_size = 512
print(f"\nTesting batch processing (batch_size={batch_size})...")
torch.manual_seed(42)
q = torch.randn(batch_size, seq_len, num_heads, head_dim,
dtype=torch.float16, device='cuda')
k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim,
dtype=torch.float16, device='cuda')
cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda')
out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size)
# Verify all batches have correct cache length
assert (cache_seqlens == seq_len).all(), f"Cache seqlens mismatch: {cache_seqlens}"
assert out.shape == (batch_size, seq_len, num_heads, head_dim)
# Compare with reference for each batch item
k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2)
v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2)
ref_out = reference_attention(q, k_expanded, v_expanded)
max_diff = (ref_out - out).abs().max().item()
print(f" Output shape: {out.shape}")
print(f" Max diff vs reference: {max_diff:.6f}")
if max_diff < 1e-2:
print(f" PASSED")
return True
else:
print(f" FAILED")
return False
# ============================================================
# Main Test Script
# ============================================================
if __name__ == "__main__":
print("=" * 60)
print("Testing flash_attn_with_kvcache chunked prefill")
print("=" * 60)
all_passed = True
all_passed &= test_chunked_prefill_correctness()
all_passed &= test_incremental_decode()
all_passed &= test_batch_processing()
print("\n" + "=" * 60)
if all_passed:
print("test_flash_attn_kvcache: ALL TESTS PASSED")
else:
print("test_flash_attn_kvcache: SOME TESTS FAILED")
print("=" * 60)

View File

@@ -1,104 +0,0 @@
"""
Test FlashInfer chunked attention with CPU offload.
Uses single_prefill_with_kv_cache + merge_state for chunked KV processing.
"""
import torch
import flashinfer
# ============================================================
# Core Functions
# ============================================================
def chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk_size, kv_chunk_size):
"""
Chunked causal attention with KV on CPU.
q: [seq_q, num_heads, head_dim] on GPU
k_cpu, v_cpu: [seq_kv, num_kv_heads, head_dim] on CPU
"""
seq_q = q.shape[0]
seq_kv = k_cpu.shape[0]
final_outputs = []
for q_start in range(0, seq_q, q_chunk_size):
q_end = min(q_start + q_chunk_size, seq_q)
q_chunk = q[q_start:q_end]
merged_output = None
merged_lse = None
for kv_start in range(0, seq_kv, kv_chunk_size):
kv_end = min(kv_start + kv_chunk_size, seq_kv)
if kv_start >= q_end:
continue
k_chunk = k_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
v_chunk = v_cpu[kv_start:kv_end].to(q.device, non_blocking=True)
causal = not (kv_end <= q_start)
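# KV chunks entirely before this Q chunk are fully visible to every query token,
# so no causal mask is needed; only the diagonal chunk (same positions as Q) is causal.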
partial_out, partial_lse = flashinfer.single_prefill_with_kv_cache(
q_chunk, k_chunk, v_chunk,
causal=causal,
return_lse=True,
)
if merged_output is None:
merged_output, merged_lse = partial_out, partial_lse
else:
merged_output, merged_lse = flashinfer.merge_state(
merged_output, merged_lse,
partial_out, partial_lse,
)
final_outputs.append(merged_output)
return torch.cat(final_outputs, dim=0)
# ============================================================
# Main Test Script
# ============================================================
print("=" * 60)
print("Testing FlashInfer chunked attention with CPU offload")
print("=" * 60)
num_heads = 32
num_kv_heads = 8
head_dim = 128
test_configs = [
(32768, 8192, 8192), # 32K tokens
(65536, 8192, 8192), # 64K tokens
(131072, 16384, 16384), # 128K tokens
# (262144, 16384, 16384), # 256K tokens (slow)
# (524288, 16384, 16384), # 512K tokens (slow)
]
for seq_len, q_chunk, kv_chunk in test_configs:
torch.manual_seed(42)
q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device='cuda')
k_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
v_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu')
# Chunked result
chunked_out = chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk, kv_chunk)
# Reference
k_gpu = k_cpu.to('cuda')
v_gpu = v_cpu.to('cuda')
ref_out = flashinfer.single_prefill_with_kv_cache(q, k_gpu, v_gpu, causal=True)
max_diff = (ref_out - chunked_out).abs().max().item()
mean_diff = (ref_out - chunked_out).abs().mean().item()
num_chunks = (seq_len + q_chunk - 1) // q_chunk
assert max_diff < 1e-2, f"FAILED: max_diff={max_diff:.6f}"
print(f"seq={seq_len//1024}K, chunks={num_chunks}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}")
print("\ntest_flashinfer_merge: PASSED")

View File

@@ -1,134 +0,0 @@
"""
Test NanovllmSteppable: Print activation statistics at each layer.
Usage:
python tests/test_nanovllm_steppable.py
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
import torch
from transformers import AutoTokenizer
from nanovllm import LLM
from nanovllm.debug.adapters.nanovllm_adapter import NanovllmSteppable
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
# Config
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
INPUT_LEN = 32768 # Longer context to test offload
MAX_NEW_TOKENS = 20
DTYPE = torch.float16
ENABLE_CPU_OFFLOAD = True # Test offload mode
# ============================================================
# Load Model
# ============================================================
print(f"Loading nanovllm model (cpu_offload={ENABLE_CPU_OFFLOAD})...")
llm = LLM(
MODEL_PATH,
enforce_eager=True, # Required for hooks to work
max_model_len=40960,
max_num_batched_tokens=40960,
enable_cpu_offload=ENABLE_CPU_OFFLOAD,
dtype="float16",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Get the underlying model for steppable
model = llm.model_runner.model
# ============================================================
# Prepare Input (using needle-in-haystack prompt)
# ============================================================
prompt, expected_answer = generate_needle_prompt(
tokenizer,
target_length=INPUT_LEN,
needle_position=0.5,
needle_value="7492",
use_chat_template=False,
verbose=True,
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
print(f"Input shape: {input_ids.shape}")
print(f"Expected answer: {expected_answer}\n")
# ============================================================
# Create Steppable Model (reused for prefill + decode)
# ============================================================
steppable = NanovllmSteppable(model)
# ============================================================
# Prefill Phase: Print activation stats
# ============================================================
print("=" * 85)
print("PREFILL PHASE")
print("=" * 85)
print(f"{'Layer':<15} {'Shape':<25} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
print("-" * 85)
current_ids = input_ids.clone()
logits = None
for bp in steppable.step(current_ids, is_prefill=True):
t = bp.tensor.float()
shape_str = str(list(t.shape))
print(f"{bp.name:<15} {shape_str:<25} {t.mean():>10.4f} {t.std():>10.4f} {t.min():>10.4f} {t.max():>10.4f}")
if bp.name == "LM Head":
logits = bp.tensor
# Get first token from prefill
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
generated_tokens = [next_token]
# ============================================================
# Decode Phase: Only print generated tokens
# ============================================================
print("\n" + "=" * 85)
print("DECODE PHASE")
print("=" * 85)
print(f"Step 1: {next_token!r}")
for step in range(2, MAX_NEW_TOKENS + 1):
# Forward pass with the full sequence (reuse the same steppable)
# Note: without a KV cache, nanovllm recomputes the full sequence for each decode step
for bp in steppable.step(current_ids, is_prefill=True):
if bp.name == "LM Head":
logits = bp.tensor
# Get next token (greedy)
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
generated_tokens.append(next_token)
print(f"Step {step:2}: {next_token!r}")
# Stop if EOS
if next_token_id == tokenizer.eos_token_id:
print(" (EOS)")
break
# Append to sequence
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
# ============================================================
# Result
# ============================================================
print("\n" + "=" * 85)
print("RESULT")
print("=" * 85)
generated_text = "".join(generated_tokens)
print(f"Generated: {generated_text!r}")
print(f"Expected: {expected_answer}")
print(f"Answer: {'CORRECT!' if check_needle_answer(generated_text, expected_answer) else 'INCORRECT'}")
print("\ntest_nanovllm_steppable: PASSED")

View File

@@ -1,695 +0,0 @@
"""
Test script to verify CPU offload correctness using distinctive KV patterns.
Strategy:
1. Hook into attention forward pass
2. Overwrite K/V with distinctive patterns based on chunk_idx (e.g., K=chunk_idx, V=-chunk_idx)
3. After offload to CPU, verify CPU cache contains correct patterns
4. On subsequent chunks, verify loaded KV from CPU has correct patterns
This catches bugs like:
- Wrong block being offloaded
- Wrong block being loaded
- Data corruption during transfer
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import torch
from random import randint, seed
from nanovllm import LLM, SamplingParams
from nanovllm.utils.context import get_context
# ============================================================
# Configuration
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 64 * 1024
NUM_GPU_BLOCKS = 4
INPUT_LEN = 32 * 1024 # 32K tokens = 32 chunks (fits in 40 CPU blocks)
BLOCK_SIZE = 1024
# Test state
errors = []
chunk_patterns = {} # chunk_idx -> (k_pattern, v_pattern)
block_coverage = {}  # chunk_idx -> {'expected': set, 'actual': set} of previous blocks
load_operations = []  # List of dicts recording each H2D load (chunk_idx, slot_idx, cpu_block_id, expected/actual K/V, k_ok, v_ok)
current_chunk_for_load = [0] # Mutable container to track current chunk during loads
# ============================================================
# Pattern Helpers
# ============================================================
def get_expected_pattern(chunk_idx: int):
"""Get expected K/V pattern for a chunk."""
# Use float values that are easy to identify
k_val = float(chunk_idx + 1) # 1.0, 2.0, 3.0, ...
v_val = float(-(chunk_idx + 1)) # -1.0, -2.0, -3.0, ...
return k_val, v_val
def fill_with_pattern(tensor: torch.Tensor, value: float):
"""Fill tensor with a constant value."""
tensor.fill_(value)
def check_pattern(tensor: torch.Tensor, expected: float, name: str, tolerance: float = 1e-3):
"""Check if tensor contains expected pattern."""
actual_mean = tensor.float().mean().item()
if abs(actual_mean - expected) > tolerance:
return False, f"{name}: expected mean={expected}, got {actual_mean}"
return True, None
# ============================================================
# Load Verification Instrumentation
# ============================================================
_original_load_to_slot_layer = None
_offload_engine_ref = None
def make_verified_load_to_slot_layer(original_func, offload_engine):
"""
Create a wrapper around load_to_slot_layer that verifies each load operation.
After each H2D transfer, checks that the GPU slot contains the expected
pattern from the source CPU block.
"""
def verified_load(slot_idx: int, layer_id: int, cpu_block_id: int):
# Call original load
original_func(slot_idx, layer_id, cpu_block_id)
# Only verify layer 0 to reduce overhead
if layer_id != 0:
return
# IMPORTANT: Synchronize CUDA to ensure async transfer is complete
# The transfer happens on a per-slot stream, and wait_slot_layer only
# makes compute_stream wait. We need full sync to read on default stream.
torch.cuda.synchronize()
# Get the expected pattern for this CPU block
# cpu_block_id == chunk_idx in our sequential test
expected_k, expected_v = get_expected_pattern(cpu_block_id)
# Read GPU slot data (GPU cache has no layer dimension)
gpu_k = offload_engine.k_cache_gpu[slot_idx]
gpu_v = offload_engine.v_cache_gpu[slot_idx]
actual_k = gpu_k.float().mean().item()
actual_v = gpu_v.float().mean().item()
k_ok = abs(actual_k - expected_k) < 1e-3
v_ok = abs(actual_v - expected_v) < 1e-3
chunk_idx = current_chunk_for_load[0]
load_operations.append({
'chunk_idx': chunk_idx,
'slot_idx': slot_idx,
'cpu_block_id': cpu_block_id,
'expected_k': expected_k,
'expected_v': expected_v,
'actual_k': actual_k,
'actual_v': actual_v,
'k_ok': k_ok,
'v_ok': v_ok,
})
if not (k_ok and v_ok):
errors.append(f"Load verification failed: chunk {chunk_idx}, "
f"CPU block {cpu_block_id} -> GPU slot {slot_idx}: "
f"expected K={expected_k:.1f}/V={expected_v:.1f}, "
f"got K={actual_k:.4f}/V={actual_v:.4f}")
return verified_load
def install_load_verification(llm):
"""Install verification wrapper on load_to_slot_layer."""
global _original_load_to_slot_layer, _offload_engine_ref
oe = llm.model_runner.kvcache_manager.offload_engine
_offload_engine_ref = oe
_original_load_to_slot_layer = oe.load_to_slot_layer
oe.load_to_slot_layer = make_verified_load_to_slot_layer(
_original_load_to_slot_layer, oe
)
print("Installed load verification wrapper on load_to_slot_layer")
def uninstall_load_verification():
"""Restore original load_to_slot_layer."""
global _original_load_to_slot_layer, _offload_engine_ref
if _offload_engine_ref is not None and _original_load_to_slot_layer is not None:
_offload_engine_ref.load_to_slot_layer = _original_load_to_slot_layer
print("Restored original load_to_slot_layer")
_original_load_to_slot_layer = None
_offload_engine_ref = None
# ============================================================
# Attention Hook
# ============================================================
def make_kv_pattern_pre_hook(layer_id: int):
"""
Create a PRE-forward hook that overwrites K/V with distinctive patterns BEFORE
they are stored to cache. This is called before attention.forward().
register_forward_pre_hook receives (module, inputs) and can modify inputs in-place.
"""
def hook(module, inputs):
ctx = get_context()
if not ctx.is_prefill:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
# Only process layer 0 for cleaner output
if layer_id != 0:
return
q, k, v = inputs
k_pattern, v_pattern = get_expected_pattern(chunk_idx)
# === Overwrite current chunk's K/V with distinctive pattern ===
# This happens BEFORE forward(), so these values will be stored to cache
k.fill_(k_pattern)
v.fill_(v_pattern)
# Only print for first few and last few chunks to reduce noise
num_chunks = INPUT_LEN // BLOCK_SIZE
if chunk_idx < 3 or chunk_idx >= num_chunks - 2:
print(f"[Chunk {chunk_idx:3d}] Set K={k_pattern:.1f}, V={v_pattern:.1f}")
elif chunk_idx == 3:
print(f"... (chunks 3 to {num_chunks - 3} omitted) ...")
return hook
def make_block_coverage_pre_hook(layer_id: int):
"""
Create a PRE-forward hook to verify that all previous blocks are included
in the cpu_block_table for chunked prefill attention.
This catches bugs where:
- Some blocks are missing from the computation
- Sparse policy incorrectly filters out blocks (when not intended)
- Block table construction has off-by-one errors
"""
def hook(module, inputs):
ctx = get_context()
if not ctx.is_prefill:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
# Only process layer 0 for cleaner output
if layer_id != 0:
return
# Update current chunk for load verification tracking
current_chunk_for_load[0] = chunk_idx
# No previous blocks for chunk 0
if chunk_idx == 0:
return
# Get the sequence and its block table (same logic as _chunked_prefill_attention)
seq = ctx.chunked_seq if hasattr(ctx, 'chunked_seq') else None
if seq is None:
return
# Get the CPU block table that will be used for attention
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Expected blocks: 0 to chunk_idx-1 (all previous chunks)
expected_blocks = set(range(chunk_idx))
actual_blocks = set(cpu_block_table) if cpu_block_table else set()
# Store for later summary
block_coverage[chunk_idx] = {
'expected': expected_blocks,
'actual': actual_blocks,
}
# Check for missing blocks
missing_blocks = expected_blocks - actual_blocks
extra_blocks = actual_blocks - expected_blocks
num_chunks = INPUT_LEN // BLOCK_SIZE
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or missing_blocks:
if not missing_blocks and not extra_blocks:
print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [OK]")
else:
status_parts = []
if missing_blocks:
status_parts.append(f"MISSING {sorted(missing_blocks)}")
if extra_blocks:
status_parts.append(f"EXTRA {sorted(extra_blocks)}")
print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [{', '.join(status_parts)}]")
elif chunk_idx == 4:
# Indicate that middle chunks are being verified silently
print(f" ... (verifying chunks 4-{num_chunks - 3} silently) ...")
if missing_blocks:
errors.append(f"Chunk {chunk_idx} missing blocks: {sorted(missing_blocks)}")
return hook
def make_gpu_write_verification_post_hook(layer_id: int):
"""
Create a POST-forward hook to verify the current chunk's KV was correctly
written to the GPU ring buffer write_slot.
This is a more reliable verification than checking load slots, because:
1. Post-hook runs AFTER forward() writes to GPU cache
2. write_slot mapping is deterministic: chunk_idx % num_ring_slots
3. We injected known patterns in pre-hook, now verify they're in GPU cache
"""
def hook(module, inputs, output):
ctx = get_context()
if not ctx.is_prefill:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
# Only process layer 0 for cleaner output
if layer_id != 0:
return
oe = kvcache_manager.offload_engine
num_ring_slots = oe.num_ring_slots
write_slot = chunk_idx % num_ring_slots
# Get expected pattern for current chunk
expected_k, expected_v = get_expected_pattern(chunk_idx)
# Verify write_slot contains current chunk's data (GPU cache has no layer dimension)
gpu_k = oe.k_cache_gpu[write_slot]
gpu_v = oe.v_cache_gpu[write_slot]
actual_k_mean = gpu_k.float().mean().item()
actual_v_mean = gpu_v.float().mean().item()
k_ok, _ = check_pattern(gpu_k, expected_k, f"GPU slot {write_slot}")
v_ok, _ = check_pattern(gpu_v, expected_v, f"GPU slot {write_slot}")
num_chunks = INPUT_LEN // BLOCK_SIZE
# Print for first/last chunks, or if there's an error
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or not (k_ok and v_ok):
if k_ok and v_ok:
print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: K={expected_k:.1f}, V={expected_v:.1f} [OK]")
else:
print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: expected K={expected_k:.1f}/V={expected_v:.1f}, "
f"got K={actual_k_mean:.2f}/V={actual_v_mean:.2f} [FAIL]")
elif chunk_idx == 4:
print(f" ... (GPU write verification for chunks 4-{num_chunks - 3} silently) ...")
if not (k_ok and v_ok):
errors.append(f"GPU write_slot {write_slot} at chunk {chunk_idx}: "
f"expected K={expected_k}, V={expected_v}, got K={actual_k_mean:.4f}, V={actual_v_mean:.4f}")
return hook
def make_kv_verification_post_hook(layer_id: int):
"""
Create a POST-forward hook to verify CPU cache contains correct patterns
from previously offloaded blocks.
"""
def hook(module, inputs, output):
ctx = get_context()
if not ctx.is_prefill:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
# Only process layer 0 for cleaner output
if layer_id != 0:
return
# === Verify previously offloaded blocks in CPU cache ===
if chunk_idx >= 1:
oe = kvcache_manager.offload_engine
num_ok = 0
num_fail = 0
# Check all previously offloaded blocks
for prev_chunk in range(chunk_idx):
# CPU block ID = prev_chunk (in simple sequential case)
cpu_block_id = prev_chunk
# Get expected pattern for this block
expected_k, expected_v = get_expected_pattern(prev_chunk)
# Read from CPU cache (layer 0)
cpu_k = oe.k_cache_cpu[layer_id, cpu_block_id]
cpu_v = oe.v_cache_cpu[layer_id, cpu_block_id]
# Verify patterns
k_ok, k_err = check_pattern(cpu_k, expected_k, f"CPU K block {cpu_block_id}")
v_ok, v_err = check_pattern(cpu_v, expected_v, f"CPU V block {cpu_block_id}")
if k_ok and v_ok:
num_ok += 1
else:
num_fail += 1
if k_err:
errors.append(k_err)
if v_err:
errors.append(v_err)
# Only print summary for each chunk verification
num_chunks = INPUT_LEN // BLOCK_SIZE
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or num_fail > 0:
status = "OK" if num_fail == 0 else f"FAIL({num_fail})"
print(f" CPU verify chunk {chunk_idx:2d}: {num_ok} blocks OK [{status}]")
elif chunk_idx == 4:
print(f" ... (CPU cache verification for chunks 4-{num_chunks - 3} silently) ...")
return hook
def make_post_chunk_verification_hook(layer_id: int):
"""
Post-forward hook to verify GPU ring buffer state after attention.
"""
def hook(module, inputs, output):
ctx = get_context()
if not ctx.is_prefill or layer_id != 0:
return
chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0
kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None
if kvcache_manager is None:
return
oe = kvcache_manager.offload_engine
# After attention, the current chunk's KV should be in the GPU ring buffer
# Ring slot = chunk_idx % num_ring_slots
ring_slot = chunk_idx % oe.num_ring_slots
expected_k, expected_v = get_expected_pattern(chunk_idx)
# Check GPU ring buffer (GPU cache has no layer dimension)
gpu_k = oe.k_cache_gpu[ring_slot]
gpu_v = oe.v_cache_gpu[ring_slot]
k_ok, k_err = check_pattern(gpu_k, expected_k, f"GPU K slot {ring_slot}")
v_ok, v_err = check_pattern(gpu_v, expected_v, f"GPU V slot {ring_slot}")
if k_ok and v_ok:
print(f" [OK] GPU slot {ring_slot} (chunk {chunk_idx}): K={expected_k}, V={expected_v}")
else:
if k_err:
print(f" [FAIL] {k_err}")
errors.append(k_err)
if v_err:
print(f" [FAIL] {v_err}")
errors.append(v_err)
return hook
def register_hooks(llm):
"""Register pre and post forward hooks."""
hooks = []
model = llm.model_runner.model
for layer_idx, decoder_layer in enumerate(model.model.layers):
attn_module = decoder_layer.self_attn.attn
# PRE-forward hook 1: Verify all previous blocks are in cpu_block_table
coverage_hook = attn_module.register_forward_pre_hook(make_block_coverage_pre_hook(layer_idx))
hooks.append(coverage_hook)
# PRE-forward hook 2: Inject K/V patterns before they're stored to cache
pattern_hook = attn_module.register_forward_pre_hook(make_kv_pattern_pre_hook(layer_idx))
hooks.append(pattern_hook)
# POST-forward hook 1: Verify GPU write_slot contains current chunk's data
gpu_verify_hook = attn_module.register_forward_hook(make_gpu_write_verification_post_hook(layer_idx))
hooks.append(gpu_verify_hook)
# POST-forward hook 2: Verify CPU cache contains correct patterns after offload
cpu_verify_hook = attn_module.register_forward_hook(make_kv_verification_post_hook(layer_idx))
hooks.append(cpu_verify_hook)
return hooks
# ============================================================
# Final Verification
# ============================================================
def verify_final_cpu_state(llm, num_chunks: int):
"""Verify all CPU blocks have correct patterns after prefill completes."""
print("\n" + "=" * 60)
print("Final CPU Cache Verification")
print("=" * 60)
kvcache_manager = llm.model_runner.kvcache_manager
oe = kvcache_manager.offload_engine
num_ok = 0
num_fail = 0
fail_details = []
# After prefill, all chunks should be in CPU
for chunk_idx in range(num_chunks):
cpu_block_id = chunk_idx
expected_k, expected_v = get_expected_pattern(chunk_idx)
# Check layer 0
cpu_k = oe.k_cache_cpu[0, cpu_block_id]
cpu_v = oe.v_cache_cpu[0, cpu_block_id]
k_ok, k_err = check_pattern(cpu_k, expected_k, f"Final CPU K block {cpu_block_id}")
v_ok, v_err = check_pattern(cpu_v, expected_v, f"Final CPU V block {cpu_block_id}")
if k_ok and v_ok:
num_ok += 1
# Only print first few and last few
if chunk_idx < 3 or chunk_idx >= num_chunks - 2:
actual_k_mean = cpu_k.float().mean().item()
actual_v_mean = cpu_v.float().mean().item()
print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), "
f"V={expected_v:.1f} ({actual_v_mean:.4f}) [OK]")
elif chunk_idx == 3:
print(f" ... (blocks 3 to {num_chunks - 3} verified OK) ...")
else:
num_fail += 1
actual_k_mean = cpu_k.float().mean().item()
actual_v_mean = cpu_v.float().mean().item()
print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), "
f"V={expected_v:.1f} ({actual_v_mean:.4f}) [FAIL]")
if k_err:
errors.append(k_err)
if v_err:
errors.append(v_err)
print(f"\nTotal: {num_ok} OK, {num_fail} FAIL out of {num_chunks} blocks")
def verify_block_coverage_summary(num_chunks: int):
"""Verify that all chunks had complete block coverage during prefill."""
print("\n" + "=" * 60)
print("Block Coverage Verification Summary")
print("=" * 60)
num_ok = 0
num_fail = 0
total_blocks_expected = 0
total_blocks_computed = 0
for chunk_idx in range(1, num_chunks): # Start from 1 (chunk 0 has no previous)
if chunk_idx not in block_coverage:
print(f" Chunk {chunk_idx}: NO COVERAGE DATA [FAIL]")
errors.append(f"Chunk {chunk_idx} has no block coverage data")
num_fail += 1
continue
coverage = block_coverage[chunk_idx]
expected = coverage['expected']
actual = coverage['actual']
missing = expected - actual
total_blocks_expected += len(expected)
total_blocks_computed += len(actual)
if not missing:
num_ok += 1
else:
num_fail += 1
# Print summary
if num_fail == 0:
print(f" All {num_ok} chunks had complete block coverage [OK]")
print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})")
else:
print(f" {num_ok} chunks OK, {num_fail} chunks with missing blocks [FAIL]")
print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})")
# Verify the total is correct: sum of 0+1+2+...+(n-1) = n*(n-1)/2
expected_total = num_chunks * (num_chunks - 1) // 2
if total_blocks_expected == expected_total:
print(f" Expected total blocks matches formula: {expected_total} [OK]")
else:
print(f" Expected total mismatch: got {total_blocks_expected}, formula gives {expected_total} [FAIL]")
errors.append(f"Block coverage total mismatch")
def verify_load_operations_summary(num_chunks: int):
"""Verify all H2D load operations transferred correct data."""
print("\n" + "=" * 60)
print("H2D Load Operations Verification Summary")
print("=" * 60)
if not load_operations:
print(" WARNING: No load operations recorded!")
print(" (This may indicate load verification was not installed)")
return
num_ok = 0
num_fail = 0
loads_per_chunk = {}
for op in load_operations:
chunk_idx = op['chunk_idx']
if chunk_idx not in loads_per_chunk:
loads_per_chunk[chunk_idx] = []
loads_per_chunk[chunk_idx].append(op)
if op['k_ok'] and op['v_ok']:
num_ok += 1
else:
num_fail += 1
# Print per-chunk summary for first/last chunks
for chunk_idx in sorted(loads_per_chunk.keys()):
ops = loads_per_chunk[chunk_idx]
chunk_ok = sum(1 for op in ops if op['k_ok'] and op['v_ok'])
chunk_fail = len(ops) - chunk_ok
if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or chunk_fail > 0:
# Show loaded block IDs in order
block_ids = [op['cpu_block_id'] for op in ops]
if chunk_fail == 0:
print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks {block_ids} [OK]")
else:
print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks, {chunk_fail} FAILED [FAIL]")
for op in ops:
if not (op['k_ok'] and op['v_ok']):
print(f" CPU block {op['cpu_block_id']} -> slot {op['slot_idx']}: "
f"expected K={op['expected_k']:.1f}/V={op['expected_v']:.1f}, "
f"got K={op['actual_k']:.4f}/V={op['actual_v']:.4f}")
elif chunk_idx == 4:
print(f" ... (chunks 4-{num_chunks - 3} load verification running silently) ...")
# Print overall summary
print(f"\n Total load operations: {len(load_operations)}")
print(f" Successful: {num_ok}, Failed: {num_fail}")
if num_fail == 0:
print(f" All H2D transfers verified correct [OK]")
else:
print(f" {num_fail} H2D transfers had incorrect data [FAIL]")
# ============================================================
# Main Test Script
# ============================================================
if __name__ == "__main__":
print("=" * 60)
print("Test: CPU Offload Correctness with Distinctive KV Patterns")
print("=" * 60)
print(f"Input: {INPUT_LEN} tokens, {INPUT_LEN // BLOCK_SIZE} chunks")
print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}")
print(f"Pattern: K=chunk_idx+1, V=-(chunk_idx+1)")
print()
# 1. Initialize LLM
print("Initializing LLM...")
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=BLOCK_SIZE,
num_gpu_blocks=NUM_GPU_BLOCKS,
dtype="float16",
)
# 2. Register hooks
hooks = register_hooks(llm)
print(f"Registered {len(hooks)} hooks")
# 3. Install load verification (instrument load_to_slot_layer)
install_load_verification(llm)
# 4. Generate prompt
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
num_chunks = INPUT_LEN // BLOCK_SIZE
# 5. Run prefill
print("\n" + "=" * 60)
print("Running Prefill with KV Pattern Injection...")
print("=" * 60)
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
# 6. Remove hooks and uninstall load verification
for hook in hooks:
hook.remove()
uninstall_load_verification()
# 7. Final verification
verify_final_cpu_state(llm, num_chunks)
# 8. Block coverage summary
verify_block_coverage_summary(num_chunks)
# 9. H2D load operations summary
verify_load_operations_summary(num_chunks)
# 10. Report results
print("\n" + "=" * 60)
if errors:
print(f"test_offload_correctness: FAILED ({len(errors)} errors)")
for err in errors[:10]: # Show first 10 errors
print(f" - {err}")
exit(1)
else:
print("test_offload_correctness: PASSED")
print("=" * 60)

View File

@@ -1,119 +0,0 @@
"""
Test script for OffloadEngine - CPU-GPU KV cache transfer engine.
Demonstrates: ring buffer, H2D/D2H transfers, CUDA events, KV access.
"""
import torch
from nanovllm.kvcache.offload_engine import OffloadEngine
# ============================================================
# Utility Functions
# ============================================================
def verify(tensor: torch.Tensor, expected: float, name: str) -> None:
"""Verify tensor contains expected value."""
actual = tensor.mean().item()
assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}"
# ============================================================
# Configuration
# ============================================================
NUM_LAYERS = 4
NUM_GPU_BLOCKS = 8
NUM_CPU_BLOCKS = 16
BLOCK_SIZE = 64
NUM_KV_HEADS = 4
HEAD_DIM = 32
# ============================================================
# Main Test Script
# ============================================================
# 1. Initialize
engine = OffloadEngine(
num_layers=NUM_LAYERS,
num_gpu_blocks=NUM_GPU_BLOCKS,
num_cpu_blocks=NUM_CPU_BLOCKS,
block_size=BLOCK_SIZE,
num_kv_heads=NUM_KV_HEADS,
head_dim=HEAD_DIM,
dtype=torch.float16,
)
# 2. Ring buffer slot management
for chunk_idx in range(12):
write_slot = engine.get_write_slot_for_prefill(chunk_idx)
load_slots = engine.get_load_slots_for_prefill(write_slot)
print("chunk idx", chunk_idx, "write slots:", write_slot, "load slots:", load_slots)
assert write_slot == chunk_idx % engine.num_ring_slots
assert write_slot not in load_slots
assert engine.decode_slot == 0
assert engine.get_load_slots_for_decode() == list(range(1, NUM_GPU_BLOCKS))
# 3. Per-slot per-layer H2D transfer
engine.k_cache_cpu[0, 0].fill_(42.0)
engine.v_cache_cpu[0, 0].fill_(42.5)
engine.load_to_slot_layer(slot_idx=1, layer_id=0, cpu_block_id=0)
engine.wait_slot_layer(slot_idx=1, layer_id=0)
verify(engine.k_cache_gpu[0, 1], 42.0, "H2D K")
verify(engine.v_cache_gpu[0, 1], 42.5, "H2D V")
# 4. Compute-done event (pipeline safety)
engine.record_slot_compute_done(slot_idx=1, layer_id=0)
engine.k_cache_cpu[0, 1].fill_(100.0)
engine.v_cache_cpu[0, 1].fill_(100.5)
engine.load_to_slot_layer(slot_idx=1, layer_id=0, cpu_block_id=1)
engine.wait_slot_layer(slot_idx=1, layer_id=0)
verify(engine.k_cache_gpu[0, 1], 100.0, "Reuse K")
verify(engine.v_cache_gpu[0, 1], 100.5, "Reuse V")
# 5. D2H offload
engine.k_cache_gpu[1, 2].fill_(77.0)
engine.v_cache_gpu[1, 2].fill_(77.5)
engine.offload_slot_to_cpu(slot_idx=2, cpu_block_id=5)
engine.wait_slot_offload(slot_idx=2)
verify(engine.k_cache_cpu[1, 5], 77.0, "D2H K")
verify(engine.v_cache_cpu[1, 5], 77.5, "D2H V")
# 6. KV access methods
k, v = engine.get_kv_for_slot(slot_idx=1, layer_id=0)
assert k.shape == (1, BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM)
k, v = engine.get_kv_for_slots(layer_id=0, slot_indices=[0, 1, 2])
assert k.shape == (1, 3 * BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM)
engine.k_cache_gpu[0, engine.decode_slot].fill_(33.0)
k, v = engine.get_kv_for_decode_slot_accumulated(layer_id=0, num_tokens=10)
assert k.shape == (1, 10, NUM_KV_HEADS, HEAD_DIM)
verify(k, 33.0, "Decode slot K")
# 7. Batch transfer
cpu_blocks = [2, 3, 4]
gpu_slots = [3, 4, 5]
for cpu_id in cpu_blocks:
engine.k_cache_cpu[0, cpu_id].fill_(50.0 + cpu_id)
engine.load_cpu_blocks_to_gpu_slots(layer_id=0, cpu_block_ids=cpu_blocks, gpu_slot_ids=gpu_slots)
for cpu_id, gpu_slot in zip(cpu_blocks, gpu_slots):
verify(engine.k_cache_gpu[0, gpu_slot], 50.0 + cpu_id, f"Batch slot {gpu_slot}")
# 8. Gather indices (CUDA graph compatible)
engine.update_gather_indices(layer_id=0, mappings=[(0, 0), (1, 1), (2, 2)])
assert engine.gather_indices_gpu[0, :3].tolist() == [0, 1, 2]
engine.clear_gather_indices(layer_id=0)
assert engine.gather_indices_gpu[0, 0].item() == -1
print("test_offload_engine: PASSED")

View File

@@ -1,51 +0,0 @@
"""
Test script for chunked prefill with CPU offload.
Demonstrates: LLM initialization, prefill execution with CPU offload enabled.
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG"
from random import randint, seed
from nanovllm import LLM, SamplingParams
# ============================================================
# Configuration
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
MAX_MODEL_LEN = 32 * 1024
NUM_GPU_BLOCKS = 2
INPUT_LEN = 16 * 1024
# ============================================================
# Main Test Script
# ============================================================
# 1. Initialize LLM with CPU offload
llm = LLM(
MODEL_PATH,
enforce_eager=True,
max_model_len=MAX_MODEL_LEN,
max_num_batched_tokens=MAX_MODEL_LEN,
enable_cpu_offload=True,
kvcache_block_size=1024,
num_gpu_blocks=NUM_GPU_BLOCKS,
)
# 2. Generate random prompt tokens
seed(42)
prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]]
# 3. Run prefill (max_tokens=1 to focus on prefill only)
sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1)
outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False)
# 4. Verify output
assert len(outputs) == 1
assert "token_ids" in outputs[0]
assert len(outputs[0]["token_ids"]) == 1
print("test_prefill: PASSED")

tests/test_sequential.py Normal file
View File

@@ -0,0 +1,199 @@
"""
Sequential inference test for LLM.
Tests: After completing one prompt, the system can correctly handle
a second prompt with a clean state (first prompt's KV cache deallocated).
"""
import os
os.environ["NANOVLLM_LOG_LEVEL"] = "INFO"
import argparse
from nanovllm import LLM, SamplingParams
from utils import generate_needle_prompt, check_needle_answer
def run_sequential_test(
model_path: str,
max_model_len: int,
input_len: int,
num_gpu_blocks: int = 4,
block_size: int = 1024,
enable_cpu_offload: bool = False,
verbose: bool = True,
) -> bool:
"""
Run the sequential inference test with three different prompts.
Each prompt has a different needle value; all three must be retrieved correctly.
"""
if verbose:
print(f"\n{'='*60}")
print(f"Sequential Inference Test")
print(f"{'='*60}")
print(f"Model: {model_path}")
print(f"Max model len: {max_model_len}")
print(f"Input length: {input_len}")
print(f"Block size: {block_size}")
print(f"CPU offload: {enable_cpu_offload}")
print(f"{'='*60}\n")
# Initialize LLM once
llm_kwargs = {
"enforce_eager": True,
"max_model_len": max_model_len,
"max_num_batched_tokens": max_model_len,
"enable_cpu_offload": enable_cpu_offload,
"kvcache_block_size": block_size,
}
if enable_cpu_offload:
llm_kwargs["num_gpu_blocks"] = num_gpu_blocks
llm = LLM(model_path, **llm_kwargs)
sampling_params = SamplingParams(
temperature=0.6,
max_tokens=32,
)
# ============================================================
# Test 1: First prompt with needle value "1234"
# ============================================================
needle_value_1 = "1234"
if verbose:
print(f"\n[Test 1] Generating prompt with needle value: {needle_value_1}")
prompt_1, expected_1 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_1,
)
outputs_1 = llm.generate([prompt_1], sampling_params, use_tqdm=True)
output_text_1 = outputs_1[0]["text"]
passed_1 = check_needle_answer(output_text_1, expected_1)
if verbose:
print(f" Expected: {expected_1}")
print(f" Output: {output_text_1[:100]}...")
print(f" Status: {'PASSED' if passed_1 else 'FAILED'}")
# ============================================================
# Test 2: Second prompt with needle value "5678"
# ============================================================
needle_value_2 = "5678"
if verbose:
print(f"\n[Test 2] Generating prompt with needle value: {needle_value_2}")
prompt_2, expected_2 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_2,
)
outputs_2 = llm.generate([prompt_2], sampling_params, use_tqdm=True)
output_text_2 = outputs_2[0]["text"]
passed_2 = check_needle_answer(output_text_2, expected_2)
if verbose:
print(f" Expected: {expected_2}")
print(f" Output: {output_text_2[:100]}...")
print(f" Status: {'PASSED' if passed_2 else 'FAILED'}")
# ============================================================
# Test 3: Third prompt with a fresh needle value to confirm no cross-contamination
# ============================================================
needle_value_3 = "9999"
if verbose:
print(f"\n[Test 3] Generating prompt with needle value: {needle_value_3}")
prompt_3, expected_3 = generate_needle_prompt(
tokenizer=llm.tokenizer,
target_length=input_len,
needle_position=0.5,
needle_value=needle_value_3,
)
outputs_3 = llm.generate([prompt_3], sampling_params, use_tqdm=True)
output_text_3 = outputs_3[0]["text"]
passed_3 = check_needle_answer(output_text_3, expected_3)
if verbose:
print(f" Expected: {expected_3}")
print(f" Output: {output_text_3[:100]}...")
print(f" Status: {'PASSED' if passed_3 else 'FAILED'}")
# ============================================================
# Summary
# ============================================================
all_passed = passed_1 and passed_2 and passed_3
if verbose:
print(f"\n{'='*60}")
print(f"Summary")
print(f"{'='*60}")
print(f"Test 1 (needle={needle_value_1}): {'PASSED' if passed_1 else 'FAILED'}")
print(f"Test 2 (needle={needle_value_2}): {'PASSED' if passed_2 else 'FAILED'}")
print(f"Test 3 (needle={needle_value_3}): {'PASSED' if passed_3 else 'FAILED'}")
print(f"Overall: {'PASSED' if all_passed else 'FAILED'}")
print(f"{'='*60}\n")
return all_passed
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Sequential inference test")
parser.add_argument(
"--model", "-m",
type=str,
default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"),
help="Path to model"
)
parser.add_argument(
"--max-model-len",
type=int,
default=36 * 1024,
help="Maximum model context length"
)
parser.add_argument(
"--input-len",
type=int,
default=8 * 1024,
help="Target input sequence length"
)
parser.add_argument(
"--num-gpu-blocks",
type=int,
default=2,
help="Number of GPU blocks for CPU offload"
)
parser.add_argument(
"--block-size",
type=int,
default=1024,
help="KV cache block size"
)
parser.add_argument(
"--enable-offload",
action="store_true",
help="Enable CPU offload"
)
args = parser.parse_args()
passed = run_sequential_test(
model_path=args.model,
max_model_len=args.max_model_len,
input_len=args.input_len,
num_gpu_blocks=args.num_gpu_blocks,
block_size=args.block_size,
enable_cpu_offload=args.enable_offload,
verbose=True,
)
if passed:
print("test_sequential: PASSED")
else:
print("test_sequential: FAILED")
exit(1)

tests/test_sgdma.py
View File

@@ -1,176 +0,0 @@
"""
Tests for CUDA sgDMA (cudaMemcpy2D) extension.
Author: Zijie Tian
"""
import torch
import time
from nanovllm.comm import memcpy_2d
# ============================================================
# Configuration
# ============================================================
class Config:
num_layers = 32
num_blocks = 10
block_size = 4096
num_kv_heads = 8
head_dim = 128
dtype = torch.float16
@property
def features_per_block(self):
return self.block_size * self.num_kv_heads * self.head_dim
@property
def bytes_per_block(self):
return self.features_per_block * self.dtype.itemsize
@property
def bytes_per_layer(self):
return self.num_blocks * self.bytes_per_block
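# With these defaults: 4096 tokens * 8 heads * 128 dim * 2 bytes = 8 MiB per block,
# 10 blocks = 80 MiB per layer, and gathering one block across all 32 layers moves 256 MiB.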
# ============================================================
# Performance Benchmark
# ============================================================
def benchmark_sgdma():
"""Benchmark cudaMemcpy2D vs standard PyTorch methods."""
print("\n=== Performance Benchmark ===")
cfg = Config()
print(f" Configuration:")
print(f" num_layers: {cfg.num_layers}")
print(f" num_blocks: {cfg.num_blocks}")
print(f" block_size: {cfg.block_size}")
print(f" dtype: {cfg.dtype}")
print(f" bytes_per_block: {cfg.bytes_per_block / 1024:.1f} KB")
print(f" total transfer size: {cfg.num_layers * cfg.bytes_per_block / 1024 / 1024:.1f} MB")
num_iterations = 10
warmup = 3
test_block_id = 5
# Allocate memory
cpu_strided = torch.randn(
cfg.num_layers,
cfg.num_blocks,
cfg.features_per_block,
dtype=cfg.dtype,
pin_memory=True
)
# ========================================
# Method A: cudaMemcpy2D with sgDMA
# ========================================
gpu_buffer_a = torch.empty(cfg.num_layers, cfg.features_per_block, dtype=cfg.dtype, device='cuda')
spitch = cfg.bytes_per_layer
dpitch = cfg.bytes_per_block
width = cfg.bytes_per_block
height = cfg.num_layers
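# 2D copy geometry: gather block `test_block_id` from every layer. Source rows are
# bytes_per_layer apart (strided CPU cache), destination rows are packed back-to-back,
# and each row copies exactly one block (bytes_per_block).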
src_view = cpu_strided[:, test_block_id, :]
# Warmup
for _ in range(warmup):
memcpy_2d(gpu_buffer_a, src_view, dpitch, spitch, width, height, "h2d")
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_iterations):
memcpy_2d(gpu_buffer_a, src_view, dpitch, spitch, width, height, "h2d")
torch.cuda.synchronize()
elapsed_a = time.perf_counter() - start
avg_time_a = elapsed_a / num_iterations * 1000 # ms
bandwidth_a = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_a
print(f"\n Method A (cudaMemcpy2D sgDMA):")
print(f" Avg time: {avg_time_a:.3f} ms")
print(f" Bandwidth: {bandwidth_a:.2f} GB/s")
# ========================================
# Method B: PyTorch .cuda() on strided view
# ========================================
# Warmup
for _ in range(warmup):
_ = cpu_strided[:, test_block_id, :].cuda()
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_iterations):
_ = cpu_strided[:, test_block_id, :].cuda()
torch.cuda.synchronize()
elapsed_b = time.perf_counter() - start
avg_time_b = elapsed_b / num_iterations * 1000 # ms
bandwidth_b = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_b
print(f"\n Method B (PyTorch .cuda() on strided):")
print(f" Avg time: {avg_time_b:.3f} ms")
print(f" Bandwidth: {bandwidth_b:.2f} GB/s")
# ========================================
# Method C: PyTorch .cuda() on contiguous (pinned)
# ========================================
# Create contiguous version with pinned memory
cpu_contiguous = torch.empty(
cfg.num_layers,
cfg.features_per_block,
dtype=cfg.dtype,
pin_memory=True
)
cpu_contiguous.copy_(cpu_strided[:, test_block_id, :])
# Warmup
for _ in range(warmup):
_ = cpu_contiguous.cuda()
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_iterations):
_ = cpu_contiguous.cuda()
torch.cuda.synchronize()
elapsed_c = time.perf_counter() - start
avg_time_c = elapsed_c / num_iterations * 1000 # ms
bandwidth_c = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_c
print(f"\n Method C (PyTorch .cuda() on contiguous):")
print(f" Avg time: {avg_time_c:.3f} ms")
print(f" Bandwidth: {bandwidth_c:.2f} GB/s")
# Summary
print(f"\n ========================================")
print(f" Performance Summary:")
print(f" Method A vs Method B: {bandwidth_a / bandwidth_b:.2f}x speedup")
print(f" Method A vs Method C: {bandwidth_a / bandwidth_c * 100:.2f}%")
print(f" ========================================")
# ============================================================
# Main
# ============================================================
if __name__ == "__main__":
print("=== CUDA sgDMA (cudaMemcpy2D) Benchmark ===")
# Check CUDA availability
if not torch.cuda.is_available():
print("CUDA not available. Skipping benchmark.")
exit(1)
# Print GPU info
print(f"Using GPU: {torch.cuda.get_device_name()}")
# Run benchmark
benchmark_sgdma()
print("\n=== Benchmark Complete ===")

tests/test_torch_steppable.py
View File

@@ -1,121 +0,0 @@
"""
Test TorchSteppable: Print activation statistics at each layer.
Usage:
python tests/test_torch_steppable.py
"""
import os
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
import torch
from transformers import AutoTokenizer
from modeling_qwen3 import Qwen3ForCausalLM
from nanovllm.debug.adapters.torch_adapter import TorchSteppable
from utils import generate_needle_prompt, check_needle_answer
# ============================================================
# Config
# ============================================================
MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/")
INPUT_LEN = 512
MAX_NEW_TOKENS = 20
DTYPE = torch.float16
# ============================================================
# Load Model
# ============================================================
print("Loading model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE)
model = model.to("cuda").eval()
# ============================================================
# Prepare Input (using needle-in-haystack prompt)
# ============================================================
prompt, expected_answer = generate_needle_prompt(
tokenizer,
target_length=INPUT_LEN,
needle_position=0.5,
needle_value="7492",
use_chat_template=False,
verbose=True,
)
input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
print(f"Input shape: {input_ids.shape}")
print(f"Expected answer: {expected_answer}\n")
# ============================================================
# Create Steppable Model (reused for prefill + decode)
# ============================================================
steppable = TorchSteppable(model)
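# As used below, steppable.step() yields one breakpoint per probe point,
# exposing the layer name (.name) and that layer's output tensor (.tensor).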
# ============================================================
# Prefill Phase: Print activation stats
# ============================================================
print("=" * 85)
print("PREFILL PHASE")
print("=" * 85)
print(f"{'Layer':<15} {'Shape':<25} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}")
print("-" * 85)
current_ids = input_ids.clone()
logits = None
for bp in steppable.step(current_ids):
t = bp.tensor.float()
shape_str = str(list(t.shape))
print(f"{bp.name:<15} {shape_str:<25} {t.mean():>10.4f} {t.std():>10.4f} {t.min():>10.4f} {t.max():>10.4f}")
if bp.name == "LM Head":
logits = bp.tensor
# Get first token from prefill
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
generated_tokens = [next_token]
# ============================================================
# Decode Phase: Only print generated tokens
# ============================================================
print("\n" + "=" * 85)
print("DECODE PHASE")
print("=" * 85)
print(f"Step 1: {next_token!r}")
for step in range(2, MAX_NEW_TOKENS + 1):
# Forward pass (reuse same steppable)
for bp in steppable.step(current_ids):
if bp.name == "LM Head":
logits = bp.tensor
# Get next token (greedy)
next_token_id = logits[0, -1].argmax().item()
next_token = tokenizer.decode(next_token_id)
generated_tokens.append(next_token)
print(f"Step {step:2}: {next_token!r}")
# Stop if EOS
if next_token_id == tokenizer.eos_token_id:
print(" (EOS)")
break
# Append to sequence
current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1)
# ============================================================
# Result
# ============================================================
print("\n" + "=" * 85)
print("RESULT")
print("=" * 85)
generated_text = "".join(generated_tokens)
print(f"Generated: {generated_text!r}")
print(f"Expected: {expected_answer}")
print(f"Answer: {'CORRECT!' if check_needle_answer(generated_text, expected_answer) else 'INCORRECT'}")
print("\ntest_torch_steppable: PASSED")