diff --git a/CLAUDE.md b/CLAUDE.md
index 2413883..6584a15 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -75,18 +75,12 @@ for hook in hooks:
     hook.remove()
 ```

-### Alignment Testing
-
-Use `tests/test_align.py` to compare nanovllm with reference torch implementation:
-
-```bash
-python tests/test_align.py
-```
+### Reference Implementation

 Key files:
 - `tests/modeling_qwen3.py`: Reference Qwen3 implementation (torch + transformers only)
-- `tests/test_align.py`: Compares attention outputs between nanovllm and reference
 - `tests/test_needle_ref.py`: Reference needle test using custom Qwen3
+- `tests/test_needle.py`: Needle-in-haystack test for nanovllm

 ### Common Pitfalls

@@ -179,7 +173,6 @@ memcpy_2d_async(gpu_buf, cpu_cache[:, block_id], dpitch, spitch, width, height,
 **Files**:
 - `csrc/sgdma_kernel.cu`, `csrc/sgdma.cpp`: CUDA extension
 - `nanovllm/comm/sgdma.py`: Python API
-- `tests/test_sgdma.py`: Standalone benchmark
 - `kvcache/offload_engine.py`: Integration (4 methods updated)

 ### Integration Details
@@ -284,25 +277,53 @@ def _merge_output_kernel(...):
 - Total GPU time: ~1,343 ms
 - **Overall speedup with Triton**: 1.67x

-### Correctness Verification
-
-**Test**: `tests/test_chunked_attention.py`
-- 12 test cases (6 configs × 2 dtypes)
-- All tests PASS with max error < 0.01
-- float16: max_diff=0.000488, mean_diff~0.00001
-- bfloat16: max_diff=0.003906, mean_diff~0.0001
-
 ### Key Files

 - `nanovllm/kvcache/chunked_attention.py`: Triton kernels + merge function
-- `tests/test_chunked_attention.py`: Correctness tests
-- `tests/test_attention_offload.py`: Performance profiling
+
+## Known Issues and Fixes
+
+### Partial Last Block Bug (FIXED ✓)
+
+**Problem**: When prefill token count is not an exact multiple of `block_size`, decode outputs garbage.
+
+**Root Cause**: `_chunked_decode_attention` calculated `last_block_valid_tokens` using `len(seq) - 1`, which increases during decode. But CPU blocks are fixed after prefill!
+
+```python
+# BUG: len(seq) increases each decode step
+total_prefill_tokens = len(seq) - 1  # Wrong!
+last_block_valid_tokens = total_prefill_tokens % block_size  # Reads garbage from CPU
+```
+
+**Fix**: Cache original prefill length in `HybridKVCacheManager.get_prefill_len()`:
+
+```python
+# CORRECT: Use cached prefill length
+total_prefill_tokens = kvcache_manager.get_prefill_len(seq)  # Fixed value
+```
+
+**Files Modified**:
+- `nanovllm/kvcache/hybrid_manager.py`: Added `_prefill_len` dict and `get_prefill_len()` method
+- `nanovllm/layers/attention.py`: Use `get_prefill_len()` instead of `len(seq) - 1`
+
+### Block Size 4096 Race Condition (PENDING)
+
+**Problem**: `block_size=4096` with multiple chunks produces garbled output. `block_size=1024` works correctly.
+
+**Symptoms**:
+- `CUDA_LAUNCH_BLOCKING=1` makes tests pass (confirms race condition)
+- `torch.cuda.synchronize()` before `store_kvcache` fixes it (heavy-handed)
+- Issue specific to larger block sizes with multiple chunks
+
+**Current Workaround**: Default `block_size` changed from 4096 to 1024.
+
+**Root Cause**: Suspected race between `compute_stream`, `transfer_stream_main`, and per-slot streams during layer-by-layer offload. Investigation ongoing.
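+
+For the partial-last-block fix above, here is a minimal sketch of the cached-prefill-length idea. The class and method names below are hypothetical illustrations, not the repo's exact `HybridKVCacheManager` code:
+
+```python
+# Hypothetical sketch: in the repo the real logic lives in HybridKVCacheManager._prefill_len.
+class PrefillLenCache:
+    """Caches each sequence's prompt length so decode never recomputes it from len(seq)."""
+
+    def __init__(self) -> None:
+        self._prefill_len: dict[int, int] = {}  # seq_id -> tokens written during prefill
+
+    def record(self, seq_id: int, prompt_len: int) -> None:
+        # Call once, when the last prefill chunk has been written to the CPU cache.
+        self._prefill_len[seq_id] = prompt_len
+
+    def get_prefill_len(self, seq_id: int) -> int:
+        # Decode-time attention sizes the partial last block from this fixed value,
+        # not from len(seq) - 1, which grows by one every decode step.
+        return self._prefill_len[seq_id]
+
+
+cache = PrefillLenCache()
+cache.record(seq_id=0, prompt_len=5000)   # e.g. a 5000-token prompt with block_size=1024
+print(cache.get_prefill_len(0) % 1024)    # 904 valid tokens in the last block, fixed during decode
+```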
 
 ## Configuration
 
 | Parameter | Default | Notes |
 |-----------|---------|-------|
-| `kvcache_block_size` | 1024 | Tokens per block |
+| `kvcache_block_size` | 1024 | Tokens per block (changed from 4096 due to race condition) |
 | `max_num_batched_tokens` | 16384 | Set = max_model_len for long context |
 | `gpu_memory_utilization` | 0.9 | GPU memory fraction |
 | `enable_cpu_offload` | False | Enable for long context |
diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py
index d91326a..f31c185 100644
--- a/nanovllm/engine/llm_engine.py
+++ b/nanovllm/engine/llm_engine.py
@@ -62,6 +62,8 @@ class LLMEngine:
         token_ids = self.model_runner.call("run", seqs, is_prefill)
         self.scheduler.postprocess(seqs, token_ids)
         outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished]
+
+        #> Calculate number of tokens processed
         num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs)
         return outputs, num_tokens

diff --git a/nanovllm/kvcache/hybrid_manager.py b/nanovllm/kvcache/hybrid_manager.py
index 375b534..a1f7209 100644
--- a/nanovllm/kvcache/hybrid_manager.py
+++ b/nanovllm/kvcache/hybrid_manager.py
@@ -128,6 +128,9 @@ class HybridKVCacheManager(KVCacheManager):
         self.cpu_block_to_logical: Dict[int, int] = {}  # cpu_block -> logical_id

         # Prefix cache (uses logical block IDs)
+        # NOTE: Currently WRITE-ONLY in offload mode - hashes are stored but never
+        #> used for cache hit detection. This is intentional: offload mode always
+        #> allocates new blocks and doesn't reuse existing ones.
         self.hash_to_logical_id: Dict[int, int] = {}

         # Step counter for policy
@@ -258,14 +261,10 @@ class HybridKVCacheManager(KVCacheManager):
         pos_in_block = seq_len % self._block_size

         if pos_in_block == 1:
-            # Need new block
-            assert last_block.hash != -1
-
+            # Need new block (previous block is full)
             logical_id = self.free_logical_ids.popleft()
             block = self.logical_blocks[logical_id]
             block.ref_count = 1
-            block.hash = -1
-            block.token_ids = []

             # Allocate new block to CPU (ring buffer mode)
             if not self.free_cpu_blocks:
@@ -279,17 +278,13 @@ class HybridKVCacheManager(KVCacheManager):
             block_table.append(logical_id)

         elif pos_in_block == 0:
-            # Block is full, update hash for prefix cache
-            assert last_block.hash == -1
-            token_ids = seq.block(seq.num_blocks - 1)
-            prefix_hash = (
-                self.logical_blocks[block_table[-2]].hash
-                if len(block_table) > 1 else -1
-            )
-            h = self.compute_hash(token_ids, prefix_hash)
-            last_block.hash = h
-            last_block.token_ids = token_ids.copy()
-            self.hash_to_logical_id[h] = last_logical_id
+            # Block is full
+            # NOTE: Prefix cache disabled in offload mode
+            # If enabled, would compute hash and update:
+            #   h = self.compute_hash(seq.block(seq.num_blocks - 1), prefix_hash)
+            #   last_block.hash = h
+            #   self.hash_to_logical_id[h] = last_logical_id
+            pass

     def prepare_for_attention(
         self,
@@ -369,8 +364,6 @@ class HybridKVCacheManager(KVCacheManager):
         """
         assert not seq.block_table, "Sequence already has blocks"

-        h = -1  # Running hash for prefix cache
-
         for i in range(seq.num_blocks):
             # Allocate CPU block
             if not self.free_cpu_blocks:
@@ -381,19 +374,10 @@ class HybridKVCacheManager(KVCacheManager):

             cpu_block_id = self.free_cpu_blocks.popleft()

-            # Get token IDs for this block and compute hash
-            token_ids = seq.block(i)
-            if len(token_ids) == self._block_size:
-                h = self.compute_hash(token_ids, h)
-            else:
-                h = -1  # Incomplete block
-
             # Allocate logical block
             logical_id = self.free_logical_ids.popleft()
             block = self.logical_blocks[logical_id]
             block.ref_count =
1 - block.hash = h - block.token_ids = token_ids.copy() if len(token_ids) == self._block_size else [] block.location = BlockLocation.CPU block.cpu_block_id = cpu_block_id block.gpu_slot = -1 @@ -401,9 +385,11 @@ class HybridKVCacheManager(KVCacheManager): self.cpu_block_to_logical[cpu_block_id] = logical_id seq.block_table.append(logical_id) - # Update prefix cache - if h != -1: - self.hash_to_logical_id[h] = logical_id + # NOTE: Prefix cache disabled in offload mode + # If enabled, would compute hash and update: + # h = self.compute_hash(seq.block(i), prefix_hash) + # block.hash = h + # self.hash_to_logical_id[h] = logical_id def get_cpu_block_table(self, seq: Sequence) -> List[int]: """ diff --git a/tests/sgdma_cpp/CMakeLists.txt b/tests/sgdma_cpp/CMakeLists.txt deleted file mode 100644 index 3f3800a..0000000 --- a/tests/sgdma_cpp/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -cmake_minimum_required(VERSION 3.18) -project(sgdma_test CUDA CXX) - -# Find CUDA -enable_language(CUDA) -find_package(CUDA REQUIRED) - -# Set C++ standard -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CUDA_STANDARD 17) - -# CUDA flags -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 --use_fast_math") -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3") - -# Build test executable -add_executable(sgdma_test sgdma_test.cpp) -target_link_libraries(sgdma_test cudart) - -# Set output directory -set_target_properties(sgdma_test PROPERTIES - RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin -) diff --git a/tests/sgdma_cpp/sgdma_test.cpp b/tests/sgdma_cpp/sgdma_test.cpp deleted file mode 100644 index 178dd67..0000000 --- a/tests/sgdma_cpp/sgdma_test.cpp +++ /dev/null @@ -1,326 +0,0 @@ -#include -#include -#include -#include -#include -#include - -// CUDA error checking macro -#define CUDA_CHECK(call) do { \ - cudaError_t err = call; \ - if (err != cudaSuccess) { \ - std::cerr << "CUDA Error in " << __FILE__ << " at line " << __LINE__ << ": " \ - << cudaGetErrorString(err) << std::endl; \ - exit(EXIT_FAILURE); \ - } \ -} while (0) - -// Configuration matching nano-vllm realistic parameters -struct Config { - int num_layers = 32; - int num_blocks = 10; // Reduced from 100 to avoid huge allocation - int block_size = 4096; - int num_kv_heads = 8; - int head_dim = 128; - int dtype_size = 2; // float16 - - // Derived parameters (use size_t to avoid overflow) - size_t features_per_block() const { return (size_t)block_size * num_kv_heads * head_dim; } - size_t bytes_per_block() const { return features_per_block() * dtype_size; } - int total_blocks_per_layer() const { return num_blocks; } - size_t bytes_per_layer() const { return (size_t)num_blocks * bytes_per_block(); } - size_t total_bytes() const { return (size_t)num_layers * bytes_per_layer(); } -}; - -// Timer utility -class Timer { - std::chrono::high_resolution_clock::time_point start_time; -public: - void start() { start_time = std::chrono::high_resolution_clock::now(); } - double elapsed_ms() { - auto end = std::chrono::high_resolution_clock::now(); - return std::chrono::duration(end - start_time).count(); - } -}; - -// Initialize CPU memory with test pattern -void init_test_data(void* data, size_t bytes, int seed) { - uint16_t* ptr = static_cast(data); - size_t num_elements = bytes / sizeof(uint16_t); - for (size_t i = 0; i < num_elements; i++) { - ptr[i] = static_cast((seed + i) % 65536); - } -} - -// Verify data correctness -bool verify_data(const void* data1, const void* data2, size_t bytes) { - const uint16_t* p1 = static_cast(data1); - const uint16_t* p2 = static_cast(data2); - 
size_t num_elements = bytes / sizeof(uint16_t); - - for (size_t i = 0; i < num_elements; i++) { - if (p1[i] != p2[i]) { - std::cerr << "Mismatch at element " << i << ": " - << p1[i] << " != " << p2[i] << std::endl; - return false; - } - } - return true; -} - -// ============================================================ -// Test 1: Basic Functionality Test -// ============================================================ -bool test_basic_functionality(const Config& cfg) { - std::cout << "\n[Test 1] Basic Functionality Test" << std::endl; - std::cout << " Testing cudaMemcpy2D correctness with strided layout" << std::endl; - - // Allocate strided CPU memory (pinned) - // Layout: [num_layers, num_blocks, block_features] - size_t total_bytes = cfg.total_bytes(); - std::cout << " Allocating " << total_bytes / 1024.0 / 1024.0 / 1024.0 << " GB pinned memory..." << std::endl; - void* cpu_strided = nullptr; - CUDA_CHECK(cudaMallocHost(&cpu_strided, total_bytes)); - std::cout << " CPU strided memory allocated at: " << cpu_strided << std::endl; - - // Allocate GPU memory for one block (all layers) - size_t gpu_block_bytes = cfg.num_layers * cfg.bytes_per_block(); - void* gpu_data = nullptr; - CUDA_CHECK(cudaMalloc(&gpu_data, gpu_block_bytes)); - - // Allocate CPU verify buffer - void* cpu_verify = nullptr; - CUDA_CHECK(cudaMallocHost(&cpu_verify, gpu_block_bytes)); - - // Initialize strided CPU memory - init_test_data(cpu_strided, total_bytes, 12345); - - // Test: Copy block_id=5 from CPU to GPU using cudaMemcpy2D - int test_block_id = 5; - size_t spitch = cfg.bytes_per_layer(); // Source pitch (stride between layers) - size_t dpitch = cfg.bytes_per_block(); // Destination pitch (contiguous) - size_t width = cfg.bytes_per_block(); // Width to copy per row - size_t height = cfg.num_layers; // Number of rows (layers) - - // Debug: print parameters - std::cout << " cudaMemcpy2D parameters:" << std::endl; - std::cout << " spitch: " << spitch << " bytes" << std::endl; - std::cout << " dpitch: " << dpitch << " bytes" << std::endl; - std::cout << " width: " << width << " bytes" << std::endl; - std::cout << " height: " << height << " rows" << std::endl; - std::cout << " dpitch >= width: " << (dpitch >= width ? "yes" : "no") << std::endl; - std::cout << " spitch >= width: " << (spitch >= width ? "yes" : "no") << std::endl; - - // Calculate source pointer (first layer, block_id) - uint8_t* src_ptr = static_cast(cpu_strided) + test_block_id * cfg.bytes_per_block(); - - // H2D transfer - CUDA_CHECK(cudaMemcpy2D( - gpu_data, // dst - dpitch, // dpitch - src_ptr, // src - spitch, // spitch - width, // width - height, // height - cudaMemcpyHostToDevice - )); - - // D2H transfer back - CUDA_CHECK(cudaMemcpy2D( - cpu_verify, // dst - dpitch, // dpitch - gpu_data, // src - dpitch, // spitch - width, // width - height, // height - cudaMemcpyDeviceToHost - )); - - // Verify correctness - bool passed = true; - for (int layer = 0; layer < cfg.num_layers; layer++) { - uint8_t* expected_ptr = static_cast(cpu_strided) + - layer * cfg.bytes_per_layer() + - test_block_id * cfg.bytes_per_block(); - uint8_t* actual_ptr = static_cast(cpu_verify) + - layer * cfg.bytes_per_block(); - - if (!verify_data(expected_ptr, actual_ptr, cfg.bytes_per_block())) { - std::cerr << " Verification failed at layer " << layer << std::endl; - passed = false; - break; - } - } - - // Cleanup - CUDA_CHECK(cudaFreeHost(cpu_strided)); - CUDA_CHECK(cudaFreeHost(cpu_verify)); - CUDA_CHECK(cudaFree(gpu_data)); - - std::cout << " Result: " << (passed ? 
"PASSED ✓" : "FAILED ✗") << std::endl; - return passed; -} - -// ============================================================ -// Test 2: Performance Benchmark -// ============================================================ -void test_performance_benchmark(const Config& cfg) { - std::cout << "\n[Test 2] Performance Benchmark" << std::endl; - std::cout << " Configuration:" << std::endl; - std::cout << " num_layers: " << cfg.num_layers << std::endl; - std::cout << " num_blocks: " << cfg.num_blocks << std::endl; - std::cout << " block_size: " << cfg.block_size << std::endl; - std::cout << " num_kv_heads: " << cfg.num_kv_heads << std::endl; - std::cout << " head_dim: " << cfg.head_dim << std::endl; - std::cout << " dtype_size: " << cfg.dtype_size << " bytes" << std::endl; - std::cout << " bytes_per_block: " << cfg.bytes_per_block() / 1024.0 << " KB" << std::endl; - std::cout << " total transfer size: " << cfg.num_layers * cfg.bytes_per_block() / 1024.0 / 1024.0 << " MB" << std::endl; - - const int num_iterations = 100; - const int warmup = 10; - int test_block_id = 5; - - // Allocate memory - size_t total_bytes = cfg.total_bytes(); - void* cpu_strided = nullptr; - CUDA_CHECK(cudaMallocHost(&cpu_strided, total_bytes)); - - void* cpu_contiguous = nullptr; - size_t gpu_block_bytes = cfg.num_layers * cfg.bytes_per_block(); - CUDA_CHECK(cudaMallocHost(&cpu_contiguous, gpu_block_bytes)); - - void* gpu_data = nullptr; - CUDA_CHECK(cudaMalloc(&gpu_data, gpu_block_bytes)); - - init_test_data(cpu_strided, total_bytes, 12345); - init_test_data(cpu_contiguous, gpu_block_bytes, 12345); - - Timer timer; - double elapsed; - double bandwidth; - - // ======================================== - // Method A: cudaMemcpy2D with strided layout - // ======================================== - size_t spitch = cfg.bytes_per_layer(); - size_t dpitch = cfg.bytes_per_block(); - size_t width = cfg.bytes_per_block(); - size_t height = cfg.num_layers; - uint8_t* src_ptr = static_cast(cpu_strided) + test_block_id * cfg.bytes_per_block(); - - // Warmup - for (int i = 0; i < warmup; i++) { - CUDA_CHECK(cudaMemcpy2D(gpu_data, dpitch, src_ptr, spitch, width, height, cudaMemcpyHostToDevice)); - } - CUDA_CHECK(cudaDeviceSynchronize()); - - // Benchmark - timer.start(); - for (int i = 0; i < num_iterations; i++) { - CUDA_CHECK(cudaMemcpy2D(gpu_data, dpitch, src_ptr, spitch, width, height, cudaMemcpyHostToDevice)); - } - CUDA_CHECK(cudaDeviceSynchronize()); - elapsed = timer.elapsed_ms(); - bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0); - - std::cout << "\n Method A (cudaMemcpy2D strided):" << std::endl; - std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl; - std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl; - double method_a_bw = bandwidth; - - // ======================================== - // Method B: cudaMemcpy with contiguous layout (baseline) - // ======================================== - // Warmup - for (int i = 0; i < warmup; i++) { - CUDA_CHECK(cudaMemcpy(gpu_data, cpu_contiguous, gpu_block_bytes, cudaMemcpyHostToDevice)); - } - CUDA_CHECK(cudaDeviceSynchronize()); - - // Benchmark - timer.start(); - for (int i = 0; i < num_iterations; i++) { - CUDA_CHECK(cudaMemcpy(gpu_data, cpu_contiguous, gpu_block_bytes, cudaMemcpyHostToDevice)); - } - CUDA_CHECK(cudaDeviceSynchronize()); - elapsed = timer.elapsed_ms(); - bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0); - - 
std::cout << "\n Method B (cudaMemcpy contiguous):" << std::endl; - std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl; - std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl; - double method_b_bw = bandwidth; - - // ======================================== - // Method C: Layer-by-layer copy (simulate PyTorch non-contiguous) - // ======================================== - // Warmup - for (int i = 0; i < warmup; i++) { - for (int layer = 0; layer < cfg.num_layers; layer++) { - uint8_t* src_layer = static_cast(cpu_strided) + - layer * cfg.bytes_per_layer() + - test_block_id * cfg.bytes_per_block(); - uint8_t* dst_layer = static_cast(gpu_data) + layer * cfg.bytes_per_block(); - CUDA_CHECK(cudaMemcpy(dst_layer, src_layer, cfg.bytes_per_block(), cudaMemcpyHostToDevice)); - } - } - CUDA_CHECK(cudaDeviceSynchronize()); - - // Benchmark - timer.start(); - for (int i = 0; i < num_iterations; i++) { - for (int layer = 0; layer < cfg.num_layers; layer++) { - uint8_t* src_layer = static_cast(cpu_strided) + - layer * cfg.bytes_per_layer() + - test_block_id * cfg.bytes_per_block(); - uint8_t* dst_layer = static_cast(gpu_data) + layer * cfg.bytes_per_block(); - CUDA_CHECK(cudaMemcpy(dst_layer, src_layer, cfg.bytes_per_block(), cudaMemcpyHostToDevice)); - } - } - CUDA_CHECK(cudaDeviceSynchronize()); - elapsed = timer.elapsed_ms(); - bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0); - - std::cout << "\n Method C (layer-by-layer copy):" << std::endl; - std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl; - std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl; - double method_c_bw = bandwidth; - - // Summary - std::cout << "\n ========================================" << std::endl; - std::cout << " Performance Summary:" << std::endl; - std::cout << " Method A vs Method B: " << std::setprecision(2) << (method_a_bw / method_b_bw * 100) << "%" << std::endl; - std::cout << " Method A vs Method C: " << std::setprecision(2) << (method_a_bw / method_c_bw) << "x speedup" << std::endl; - std::cout << " ========================================" << std::endl; - - // Cleanup - CUDA_CHECK(cudaFreeHost(cpu_strided)); - CUDA_CHECK(cudaFreeHost(cpu_contiguous)); - CUDA_CHECK(cudaFree(gpu_data)); -} - -int main() { - std::cout << "=== cudaMemcpy2D Test ===" << std::endl; - - // Print CUDA device info - int device; - CUDA_CHECK(cudaGetDevice(&device)); - cudaDeviceProp prop; - CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); - std::cout << "Using GPU: " << prop.name << std::endl; - std::cout << "Memory Clock Rate: " << prop.memoryClockRate / 1000 << " MHz" << std::endl; - std::cout << "Memory Bus Width: " << prop.memoryBusWidth << " bits" << std::endl; - std::cout << "Peak Memory Bandwidth: " << - 2.0 * prop.memoryClockRate * (prop.memoryBusWidth / 8) / 1.0e6 << " GB/s" << std::endl; - - Config cfg; - - // Run tests - bool test1_passed = test_basic_functionality(cfg); - test_performance_benchmark(cfg); - - std::cout << "\n=== Test Complete ===" << std::endl; - std::cout << "All tests " << (test1_passed ? "PASSED ✓" : "FAILED ✗") << std::endl; - - return test1_passed ? 
0 : 1; -} diff --git a/tests/test_align.py b/tests/test_align.py deleted file mode 100644 index 798faf6..0000000 --- a/tests/test_align.py +++ /dev/null @@ -1,365 +0,0 @@ -""" -Test alignment between nanovllm and custom torch Qwen3 implementation. -Compares attention layer outputs and QKV tensors to verify correctness. - -Usage: - python test_align.py # Without CPU offload - python test_align.py --enable-offload # With CPU offload - python test_align.py --input-len 4096 # Custom input length -""" - -import os -os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING" - -import argparse -import torch -from transformers import AutoTokenizer -from nanovllm import LLM, SamplingParams -from modeling_qwen3 import Qwen3ForCausalLM -from utils import generate_needle_prompt - -# Parse arguments -parser = argparse.ArgumentParser() -parser.add_argument("--enable-offload", action="store_true", help="Enable CPU offload") -parser.add_argument("--input-len", type=int, default=1024 * 12, help="Input sequence length") -parser.add_argument("--model-path", type=str, default="~/models/Qwen3-0.6B/", help="Model path") -parser.add_argument("--num-gpu-blocks", type=int, default=6, help="Number of GPU blocks (ring buffer slots)") -parser.add_argument("--block-size", type=int, default=1024, help="KV cache block size") -args = parser.parse_args() - -# Config -MODEL_PATH = os.path.expanduser(args.model_path) -INPUT_LEN = args.input_len -ENABLE_OFFLOAD = args.enable_offload -NUM_GPU_BLOCKS = args.num_gpu_blocks -BLOCK_SIZE = args.block_size -DTYPE = torch.float16 - -print(f"Config: input_len={INPUT_LEN}, enable_offload={ENABLE_OFFLOAD}, num_gpu_blocks={NUM_GPU_BLOCKS}, block_size={BLOCK_SIZE}") - -# Storage for captured tensors -nanovllm_outputs = {} -torch_outputs = {} -nanovllm_qkv = {} -nanovllm_proj_inputs = {} -torch_proj_inputs = {} - - -# ============================================================ -# Hook functions for non-offload mode (overwrite) -# ============================================================ -def make_nanovllm_hook(layer_id: int, storage: dict): - def hook(module, inputs, output): - attn_output = output[0] if isinstance(output, tuple) else output - if attn_output.dim() == 2: - attn_output = attn_output.unsqueeze(0) - storage[layer_id] = attn_output.detach().clone() - return hook - - -def make_nanovllm_qkv_hook(layer_id: int, storage: dict): - def hook(module, inputs): - q, k, v = inputs[0], inputs[1], inputs[2] - storage[layer_id] = { - "q": q.detach().clone(), - "k": k.detach().clone(), - "v": v.detach().clone(), - } - return hook - - -def make_proj_input_hook(layer_id: int, storage: dict): - def hook(module, inputs): - hidden = inputs[0] - if hidden.dim() == 2: - hidden = hidden.unsqueeze(0) - storage[layer_id] = hidden.detach().clone() - return hook - - -# ============================================================ -# Hook functions for offload mode (accumulate Q and I, overwrite O) -# ============================================================ -def make_accumulating_q_hook(layer_id: int, storage: dict): - """Accumulate Q from each chunk for offload mode.""" - def hook(module, inputs): - q = inputs[0].detach().clone() - if layer_id not in storage: - storage[layer_id] = [] - storage[layer_id].append(q) - return hook - - -def make_accumulating_input_hook(layer_id: int, storage: dict): - """Accumulate input hidden states from each chunk for offload mode.""" - def hook(module, inputs): - hidden = inputs[0].detach().clone() - if layer_id not in storage: - storage[layer_id] = [] - 
storage[layer_id].append(hidden) - return hook - - -def make_overwrite_output_hook(layer_id: int, storage: dict): - """Overwrite output (keep only last chunk) for offload mode.""" - def hook(module, inputs, output): - attn_output = output[0] if isinstance(output, tuple) else output - if attn_output.dim() == 2: - attn_output = attn_output.unsqueeze(0) - storage[layer_id] = attn_output.detach().clone() - return hook - - -# ============================================================ -# CPU KV cache access for offload mode -# ============================================================ -def get_nanovllm_kv_from_cpu(llm, seq, num_layers): - """Get complete K, V cache from CPU side after all chunks finish.""" - offload_engine = llm.model_runner.kvcache_manager.offload_engine - kvcache_manager = llm.model_runner.kvcache_manager - - # CRITICAL: Synchronize all CUDA operations before reading CPU memory - # The D2H copy runs on transfer_stream_main and may still be in progress - torch.cuda.synchronize() - - cpu_block_ids = kvcache_manager.get_cpu_block_table(seq) - - kv_per_layer = {} - for layer_id in range(num_layers): - k_blocks = [] - v_blocks = [] - for cpu_block_id in cpu_block_ids: - k_block, v_block = offload_engine.get_cpu_block(layer_id, cpu_block_id) - k_blocks.append(k_block) - v_blocks.append(v_block) - - # Concatenate all blocks: [total_tokens, kv_heads, head_dim] - k_full = torch.cat(k_blocks, dim=0)[:seq.num_tokens] - v_full = torch.cat(v_blocks, dim=0)[:seq.num_tokens] - kv_per_layer[layer_id] = {"k": k_full, "v": v_full} - - return kv_per_layer - - -def make_torch_hook(layer_id: int, storage: dict): - def hook(module, inputs, output): - storage[layer_id] = output[0].detach().clone() - return hook - - -def cosine_sim(t1: torch.Tensor, t2: torch.Tensor) -> float: - """Cosine similarity between flattened tensors (1.0 = identical).""" - return torch.nn.functional.cosine_similarity( - t1.flatten().float(), t2.flatten().float(), dim=0 - ).item() - - -def compute_qkv_sims(nano_qkv: dict, torch_qkv: dict, num_kv_groups: int): - """Compute Q, K, V cosine similarities. 
Returns (q_sim, k_sim, v_sim).""" - nano_q = nano_qkv["q"] - torch_q = torch_qkv["q"].squeeze(0).transpose(0, 1) - - nano_k = nano_qkv["k"] - torch_k = torch_qkv["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1) - - nano_v = nano_qkv["v"] - torch_v = torch_qkv["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1) - - return cosine_sim(nano_q, torch_q), cosine_sim(nano_k, torch_k), cosine_sim(nano_v, torch_v) - - -# ============================================================ -# Load models -# ============================================================ -print("Loading nanovllm model...") -llm_kwargs = dict( - enforce_eager=True, - max_model_len=32768, - gpu_memory_utilization=0.2, - max_num_batched_tokens=32768, - enable_cpu_offload=ENABLE_OFFLOAD, - dtype="float16", - kvcache_block_size=BLOCK_SIZE, -) -if ENABLE_OFFLOAD: - llm_kwargs["num_gpu_blocks"] = NUM_GPU_BLOCKS -llm = LLM(MODEL_PATH, **llm_kwargs) - -num_heads = llm.model_runner.model.model.layers[0].self_attn.num_heads -num_kv_heads = llm.model_runner.model.model.layers[0].self_attn.num_kv_heads -num_kv_groups = num_heads // num_kv_heads -num_layers = len(llm.model_runner.model.model.layers) - -print("Loading torch model...") -torch_model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE) -torch_model = torch_model.to("cuda") -torch_model.eval() - -# ============================================================ -# Generate test input -# ============================================================ -tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) -prompt, _ = generate_needle_prompt(tokenizer=tokenizer, target_length=INPUT_LEN, verbose=True) -input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda") -print(f"Input shape: {input_ids.shape}") - -# ============================================================ -# Register hooks -# ============================================================ -nanovllm_hooks = [] -nanovllm_q_accum = {} # For offload mode: accumulated Q from all chunks -nanovllm_i_accum = {} # For offload mode: accumulated I from all chunks - -for layer_idx, layer in enumerate(llm.model_runner.model.model.layers): - if ENABLE_OFFLOAD: - # Offload mode: accumulate Q and I, overwrite O - nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_overwrite_output_hook(layer_idx, nanovllm_outputs))) - nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_accumulating_q_hook(layer_idx, nanovllm_q_accum))) - nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_accumulating_input_hook(layer_idx, nanovllm_i_accum))) - else: - # Non-offload mode: overwrite all - nanovllm_hooks.append(layer.self_attn.register_forward_hook(make_nanovllm_hook(layer_idx, nanovllm_outputs))) - nanovllm_hooks.append(layer.self_attn.attn.register_forward_pre_hook(make_nanovllm_qkv_hook(layer_idx, nanovllm_qkv))) - nanovllm_hooks.append(layer.self_attn.qkv_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, nanovllm_proj_inputs))) - -torch_hooks = [] -for layer_idx, layer in enumerate(torch_model.model.layers): - torch_hooks.append(layer.self_attn.register_forward_hook(make_torch_hook(layer_idx, torch_outputs))) - torch_hooks.append(layer.self_attn.q_proj.register_forward_pre_hook(make_proj_input_hook(layer_idx, torch_proj_inputs))) - -# ============================================================ -# Run inference -# ============================================================ -print("Running nanovllm inference...") - -if ENABLE_OFFLOAD: - # Manual execution to 
capture KV cache before deallocation - # Use max_tokens=2 so sequence doesn't finish immediately after prefill - llm.add_request(input_ids[0].tolist(), SamplingParams(temperature=0.01, max_tokens=2)) - - # Run prefill step (this calls run_chunked_offload_prefill internally) - output, num_tokens = llm.step() - print(f"[Offload] Prefill done: {num_tokens} tokens") - - # Now seq is in running queue, KV cache is in CPU - seq = llm.scheduler.running[0] - print(f"[Offload] Sequence: {seq}") - - # Get KV cache from CPU BEFORE decode step deallocates it - nanovllm_kv_cpu = get_nanovllm_kv_from_cpu(llm, seq, num_layers) - print(f"[Offload] Retrieved KV cache from CPU for {seq.num_tokens} tokens") - - # IMPORTANT: Save outputs NOW before decode step overwrites them - # nanovllm_outputs contains prefill attention outputs at this point - nanovllm_outputs_prefill = {k: v.clone() for k, v in nanovllm_outputs.items()} - - # Complete remaining steps (decode) - while not llm.is_finished(): - llm.step() - - # Use prefill outputs for comparison - nanovllm_outputs = nanovllm_outputs_prefill -else: - nanovllm_result = llm.generate([input_ids[0].tolist()], SamplingParams(temperature=0.01, max_tokens=1), use_tqdm=False) - -print("Running torch inference...") -with torch.no_grad(): - torch_logits, _, torch_qkv_outputs = torch_model(input_ids, output_qkv_layers=list(range(num_layers))) - -# ============================================================ -# Compare using cosine similarity (1.0 = perfect alignment) -# ============================================================ -print("\n" + "=" * 70) -print(f"{'Layer':<8} {'I':>10} {'Q':>10} {'K':>10} {'V':>10} {'O':>10}") -print("=" * 70) - -all_passed = True -threshold = 0.999 # Cosine similarity threshold - -for layer_idx in range(num_layers): - if ENABLE_OFFLOAD: - # ============================================================ - # Offload mode: use accumulated Q/I and CPU-side K/V - # Only compare prompt tokens (INPUT_LEN), exclude generated tokens - # ============================================================ - # I: concatenate accumulated chunks, trim to prompt length - i_chunks = nanovllm_i_accum[layer_idx] - nano_in = torch.cat(i_chunks, dim=0)[:INPUT_LEN] - if nano_in.dim() == 2: - nano_in = nano_in.unsqueeze(0) - torch_in = torch_proj_inputs[layer_idx] - if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel(): - torch_in = torch_in.view(nano_in.shape) - i_sim = cosine_sim(nano_in, torch_in) - - # Q: concatenate accumulated chunks, trim to prompt length - q_chunks = nanovllm_q_accum[layer_idx] - nano_q = torch.cat(q_chunks, dim=0)[:INPUT_LEN] - torch_q = torch_qkv_outputs[layer_idx]["q"].squeeze(0).transpose(0, 1) - q_sim = cosine_sim(nano_q, torch_q) - - # K, V: from CPU cache, trim to prompt length and move to GPU - nano_k = nanovllm_kv_cpu[layer_idx]["k"][:INPUT_LEN].cuda() - torch_k = torch_qkv_outputs[layer_idx]["k"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1) - k_sim = cosine_sim(nano_k, torch_k) - - nano_v = nanovllm_kv_cpu[layer_idx]["v"][:INPUT_LEN].cuda() - torch_v = torch_qkv_outputs[layer_idx]["v"].squeeze(0)[::num_kv_groups, :, :].transpose(0, 1) - v_sim = cosine_sim(nano_v, torch_v) - - # O: compare attention outputs directly - # For single-chunk case (input_len <= block_size), shapes should match - # For multi-chunk case, nano_out is the last chunk only - nano_out = nanovllm_outputs[layer_idx] - torch_out = torch_outputs[layer_idx] - - if nano_out.numel() == torch_out.numel(): - # Single chunk or shapes match - 
compare directly - o_sim = cosine_sim(nano_out, torch_out) - else: - # Multi-chunk case: compare last chunk with corresponding torch slice - last_chunk_len = nano_out.shape[1] if nano_out.dim() == 3 else nano_out.shape[0] - torch_out_slice = torch_out[:, -last_chunk_len:, :] if torch_out.dim() == 3 else torch_out[-last_chunk_len:, :] - o_sim = cosine_sim(nano_out, torch_out_slice) - else: - # ============================================================ - # Non-offload mode: original logic - # ============================================================ - # Input similarity - nano_in = nanovllm_proj_inputs[layer_idx] - torch_in = torch_proj_inputs[layer_idx] - - if nano_in.shape != torch_in.shape and nano_in.numel() == torch_in.numel(): - torch_in = torch_in.view(nano_in.shape) - - i_sim = cosine_sim(nano_in, torch_in) - - # QKV similarities - q_sim, k_sim, v_sim = compute_qkv_sims(nanovllm_qkv[layer_idx], torch_qkv_outputs[layer_idx], num_kv_groups) - - # O similarity - nano_out = nanovllm_outputs[layer_idx] - torch_out = torch_outputs[layer_idx] - if nano_out.shape != torch_out.shape and nano_out.numel() == torch_out.numel(): - torch_out = torch_out.view(nano_out.shape) - o_sim = cosine_sim(nano_out, torch_out) - - # Check pass/fail - passed = all(s >= threshold for s in [i_sim, q_sim, k_sim, v_sim, o_sim]) - all_passed = all_passed and passed - status = "" if passed else " *" - - print(f"Layer {layer_idx:2d}{status:<3} {i_sim:>10.6f} {q_sim:>10.6f} {k_sim:>10.6f} {v_sim:>10.6f} {o_sim:>10.6f}") - -# ============================================================ -# Cleanup and result -# ============================================================ -for hook in nanovllm_hooks + torch_hooks: - hook.remove() - -print("=" * 70) -mode_str = " [offload]" if ENABLE_OFFLOAD else "" -if all_passed: - print(f"test_align{mode_str}: PASSED (cosine_sim >= 0.999)") -else: - print(f"test_align{mode_str}: FAILED (* = cosine_sim < 0.999)") diff --git a/tests/test_attention_offload.py b/tests/test_attention_offload.py deleted file mode 100644 index 5fcbeac..0000000 --- a/tests/test_attention_offload.py +++ /dev/null @@ -1,297 +0,0 @@ -""" -Test Attention layer with KV cache offload - N-way Pipeline. 
- -This test demonstrates and verifies the N-way pipeline with: -- Per-slot transfer streams for parallel H2D -- Dedicated compute stream (avoids CUDA default stream implicit sync) -- Pre-load phase + main loop with immediate slot reuse - -Key difference from previous test: -- We first pre-fill many chunks to CPU cache -- Then simulate processing a new chunk that loads ALL previous blocks -- This exercises the full N-way pipeline with many blocks in flight -""" - -import torch -from nanovllm.layers.attention import Attention -from nanovllm.kvcache.hybrid_manager import HybridKVCacheManager -from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs -from nanovllm.engine.sequence import Sequence -from nanovllm.utils.context import set_context, reset_context - - -# ============================================================ -# Configuration -# ============================================================ - -NUM_LAYERS = 8 -NUM_HEADS = 8 -NUM_KV_HEADS = 8 -HEAD_DIM = 64 -BLOCK_SIZE = 1024 -CHUNK_SIZE = 1024 - -NUM_GPU_SLOTS = 6 # N-way pipeline with 6 slots -NUM_CPU_BLOCKS = 16 # Many blocks to load from CPU - -DTYPE = torch.bfloat16 -DEVICE = "cuda" - - -# ============================================================ -# Setup -# ============================================================ - -def create_manager(): - manager = HybridKVCacheManager( - num_gpu_slots=NUM_GPU_SLOTS, - num_cpu_blocks=NUM_CPU_BLOCKS, - block_size=BLOCK_SIZE, - ) - manager.allocate_cache( - num_layers=NUM_LAYERS, - num_kv_heads=NUM_KV_HEADS, - head_dim=HEAD_DIM, - dtype=DTYPE, - ) - return manager - - -def create_attention_layers(manager): - layers = [] - for layer_id in range(NUM_LAYERS): - attn = Attention( - num_heads=NUM_HEADS, - head_dim=HEAD_DIM, - scale=HEAD_DIM ** -0.5, - num_kv_heads=NUM_KV_HEADS, - ) - attn.layer_id = layer_id - k_cache, v_cache = manager.get_layer_cache(layer_id) - attn.k_cache = k_cache - attn.v_cache = v_cache - layers.append(attn.to(DEVICE)) - return layers - - -# ============================================================ -# Pre-fill CPU cache with random data -# ============================================================ - -def prefill_cpu_cache(manager, num_blocks): - """ - Fill CPU cache with random KV data for num_blocks blocks. - This simulates having already processed many chunks. - """ - offload_engine = manager.offload_engine - - for block_id in range(num_blocks): - # Generate random KV data for all layers - for layer_id in range(NUM_LAYERS): - k_data = torch.randn( - BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM, - dtype=DTYPE, device=DEVICE - ) - v_data = torch.randn( - BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM, - dtype=DTYPE, device=DEVICE - ) - - # Copy to CPU cache - offload_engine.k_cache_cpu[layer_id, block_id].copy_(k_data) - offload_engine.v_cache_cpu[layer_id, block_id].copy_(v_data) - - return list(range(num_blocks)) - - -# ============================================================ -# Simulate N-way Pipeline (mirrors attention.py logic) -# ============================================================ - -def simulate_nway_pipeline( - layer_id: int, - q_batched: torch.Tensor, - cpu_block_table: list, - load_slots: list, - offload_engine, - scale: float, -): - """ - Simulate N-way pipeline for a single layer. - This mirrors the logic in Attention._ring_buffer_pipeline_load(). 
- """ - num_blocks = len(cpu_block_table) - num_slots = len(load_slots) - - o_acc, lse_acc = None, None - - # Phase 1: Pre-load up to num_slots blocks - num_preload = min(num_slots, num_blocks) - torch.cuda.nvtx.range_push(f"Phase1_Preload: L{layer_id}") - for i in range(num_preload): - offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i]) - torch.cuda.nvtx.range_pop() - - # Phase 2: Main loop with compute_stream - compute_stream = offload_engine.compute_stream - - for block_idx in range(num_blocks): - torch.cuda.nvtx.range_push(f"Block: L{layer_id} B{block_idx}") - - current_slot = load_slots[block_idx % num_slots] - - # Wait for transfer - offload_engine.wait_slot_layer(current_slot, layer_id) - - # Compute on dedicated stream - with torch.cuda.stream(compute_stream): - torch.cuda.nvtx.range_push(f"FlashAttn: L{layer_id} B{block_idx}") - prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, layer_id) - prev_o, prev_lse = flash_attn_with_lse( - q_batched, prev_k, prev_v, - softmax_scale=scale, - causal=False, - ) - torch.cuda.nvtx.range_pop() - offload_engine.record_slot_compute_done(current_slot, layer_id) - - # Start next transfer (reuse current_slot) - next_block_idx = block_idx + num_slots - if next_block_idx < num_blocks: - offload_engine.load_to_slot_layer( - current_slot, layer_id, cpu_block_table[next_block_idx] - ) - - # Merge - with torch.cuda.stream(compute_stream): - if o_acc is None: - o_acc, lse_acc = prev_o, prev_lse - else: - o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) - - torch.cuda.nvtx.range_pop() - - return o_acc, lse_acc - - -def simulate_full_forward(layers, manager, cpu_block_table, chunk_size): - """ - Simulate forward pass through all layers, loading previous blocks from CPU. - This is the key test: many blocks loaded via N-way pipeline. - """ - offload_engine = manager.offload_engine - - # Current chunk index (we're processing the "next" chunk after all prefilled ones) - current_chunk_idx = len(cpu_block_table) - write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx) - load_slots = offload_engine.get_load_slots_for_prefill(write_slot) - - # Random query for attention - q = torch.randn(1, chunk_size, NUM_HEADS, HEAD_DIM, dtype=DTYPE, device=DEVICE) - - outputs = [] - for layer in layers: - torch.cuda.nvtx.range_push(f"Layer: {layer.layer_id}") - - o_acc, lse_acc = simulate_nway_pipeline( - layer.layer_id, - q, - cpu_block_table, - load_slots, - offload_engine, - layer.scale, - ) - - outputs.append(o_acc) - torch.cuda.nvtx.range_pop() - - return outputs - - -# ============================================================ -# Main Test -# ============================================================ - -print("=" * 60) -print("Test: N-way Pipeline with CPU Offload") -print("=" * 60) - -# 1. Setup -print("\n[1] Creating manager and attention layers...") -manager = create_manager() -layers = create_attention_layers(manager) -offload_engine = manager.offload_engine - -print(f" - GPU slots: {NUM_GPU_SLOTS}") -print(f" - CPU blocks: {NUM_CPU_BLOCKS}") -print(f" - Per-slot streams: {len(offload_engine.slot_transfer_streams)}") -print(f" - Compute stream: {offload_engine.compute_stream}") - -# 2. Pre-fill CPU cache -NUM_PREV_BLOCKS = 12 # Many blocks to load via N-way pipeline -print(f"\n[2] Pre-filling {NUM_PREV_BLOCKS} blocks to CPU cache...") -cpu_block_table = prefill_cpu_cache(manager, NUM_PREV_BLOCKS) -print(f" - CPU blocks filled: {cpu_block_table}") - -# 3. 
Verify pipeline configuration -current_chunk_idx = NUM_PREV_BLOCKS -write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx) -load_slots = offload_engine.get_load_slots_for_prefill(write_slot) -print(f"\n[3] Pipeline configuration for chunk {current_chunk_idx}:") -print(f" - Write slot: {write_slot}") -print(f" - Load slots: {load_slots}") -print(f" - Pipeline depth (N-way): {len(load_slots)}") -assert len(load_slots) == NUM_GPU_SLOTS - 1, f"Expected {NUM_GPU_SLOTS - 1} load slots" - -# 4. Warmup -print("\n[4] Warmup (3 iterations)...") -for i in range(3): - outputs = simulate_full_forward(layers, manager, cpu_block_table, CHUNK_SIZE) - torch.cuda.synchronize() - print(f" - Warmup {i+1}/3 done") - -# 5. Benchmark -NUM_ITERS = 10 -print(f"\n[5] Benchmark ({NUM_ITERS} iterations)...") - -torch.cuda.synchronize() -start_event = torch.cuda.Event(enable_timing=True) -end_event = torch.cuda.Event(enable_timing=True) - -start_event.record() -for i in range(NUM_ITERS): - torch.cuda.nvtx.range_push(f"Iteration_{i}") - outputs = simulate_full_forward(layers, manager, cpu_block_table, CHUNK_SIZE) - torch.cuda.nvtx.range_pop() -end_event.record() - -torch.cuda.synchronize() -elapsed_ms = start_event.elapsed_time(end_event) - -# Stats -total_blocks_loaded = NUM_PREV_BLOCKS * NUM_LAYERS * NUM_ITERS -blocks_per_sec = total_blocks_loaded / (elapsed_ms / 1000) -total_tokens = NUM_PREV_BLOCKS * BLOCK_SIZE * NUM_LAYERS * NUM_ITERS -tokens_per_sec = total_tokens / (elapsed_ms / 1000) - -print(f"\n[6] Results:") -print(f" - Total time: {elapsed_ms:.2f} ms") -print(f" - Per iteration: {elapsed_ms / NUM_ITERS:.2f} ms") -print(f" - Blocks loaded: {total_blocks_loaded} ({blocks_per_sec:.0f} blocks/s)") -print(f" - Tokens processed: {total_tokens} ({tokens_per_sec:.0f} tok/s)") - -# 7. Verification -print("\n[7] Verification:") -assert len(outputs) == NUM_LAYERS, f"Expected {NUM_LAYERS} outputs" -for i, o in enumerate(outputs): - assert o is not None, f"Layer {i} output is None" - assert o.shape == (1, CHUNK_SIZE, NUM_HEADS, HEAD_DIM), f"Layer {i} shape mismatch" -print(" - All layer outputs valid ✓") -print(" - N-way pipeline executed correctly ✓") - -# Cleanup -reset_context() - -print("\n" + "=" * 60) -print("test_attention_offload: PASSED") -print("=" * 60) diff --git a/tests/test_chunked_attention.py b/tests/test_chunked_attention.py deleted file mode 100644 index 08928e3..0000000 --- a/tests/test_chunked_attention.py +++ /dev/null @@ -1,169 +0,0 @@ -""" -Test script for chunked attention correctness. - -Validates that chunked prefill using flash_attn_with_lse + merge_attention_outputs -produces the same result as full flash_attn_varlen_func. - -Scenario: Simulating chunked prefill where we process query chunk by chunk. -For each query chunk i: -- KV contains all tokens from chunk 0 to chunk i -- Previous KV chunks (0 to i-1): full attention (no causal mask) -- Current KV chunk (i): causal attention (diagonal block) -""" - -import torch -from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_func -from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs - -# ============================================================ -# Utility Functions -# ============================================================ - -def compute_chunked_prefill_for_chunk( - q_chunk: torch.Tensor, - kv_chunks: list, - current_chunk_idx: int, -) -> torch.Tensor: - """ - Compute attention for a single query chunk against all KV chunks up to current. 
- - This simulates chunked prefill for query chunk `current_chunk_idx`: - - KV chunks 0 to current_chunk_idx-1: full attention (all previous tokens visible) - - KV chunk current_chunk_idx: causal attention (diagonal block) - - Args: - q_chunk: [batch, chunk_size, nheads, headdim] - current query chunk - kv_chunks: List of (k, v) tuples, each [batch, chunk_size, nheads, headdim] - current_chunk_idx: Index of the current chunk being processed - - Returns: - out: [batch, chunk_size, nheads, headdim] - """ - accumulated_o = None - accumulated_lse = None - - for i in range(current_chunk_idx + 1): - k_chunk, v_chunk = kv_chunks[i] - - # Previous chunks: no causal mask (all tokens visible) - # Current chunk (diagonal): causal mask - is_diagonal = (i == current_chunk_idx) - - chunk_o, chunk_lse = flash_attn_with_lse( - q_chunk, k_chunk, v_chunk, causal=is_diagonal - ) - - if accumulated_o is None: - accumulated_o = chunk_o - accumulated_lse = chunk_lse - else: - accumulated_o, accumulated_lse = merge_attention_outputs( - accumulated_o, accumulated_lse, - chunk_o, chunk_lse - ) - - return accumulated_o - - -def compute_reference_causal( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, -) -> torch.Tensor: - """ - Compute reference causal attention using flash_attn_func. - - Args: - q, k, v: [batch, seqlen, nheads, headdim] - - Returns: - out: [batch, seqlen, nheads, headdim] - """ - return flash_attn_func(q, k, v, causal=True) - - -# ============================================================ -# Main Test Script -# ============================================================ - -torch.manual_seed(42) - -# Test configurations: (batch, num_chunks, chunk_size, nheads, headdim) -TEST_CASES = [ - (1, 4, 256, 8, 128), - (1, 4, 512, 8, 128), - (1, 8, 512, 8, 128), - (1, 32, 1024, 8, 128), - (1, 32, 1024, 32, 128), # More heads - (1, 32, 256, 8, 64), # Smaller head dim -] - -DTYPES = [torch.float16, torch.bfloat16] - -print("=" * 80) -print("Test: Chunked Prefill Attention vs Reference (flash_attn_func causal)") -print("=" * 80) -print("Simulating chunked prefill: Q chunk attends to all KV chunks up to current") -print(" - Previous KV chunks: full attention (no causal mask)") -print(" - Current KV chunk (diagonal): causal attention") -print() - -all_passed = True - -for dtype in DTYPES: - print(f"--- dtype: {dtype} ---") - - for batch, num_chunks, chunk_size, nheads, headdim in TEST_CASES: - seqlen = num_chunks * chunk_size - - # Generate full Q, K, V - q_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype) - k_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype) - v_full = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype) - - # Reference: full causal attention - out_ref = compute_reference_causal(q_full, k_full, v_full) - - # Split into chunks - q_chunks = [q_full[:, i*chunk_size:(i+1)*chunk_size] for i in range(num_chunks)] - kv_chunks = [ - (k_full[:, i*chunk_size:(i+1)*chunk_size], - v_full[:, i*chunk_size:(i+1)*chunk_size]) - for i in range(num_chunks) - ] - - # Compute chunked prefill for each query chunk - out_chunks = [] - for chunk_idx in range(num_chunks): - chunk_out = compute_chunked_prefill_for_chunk( - q_chunks[chunk_idx], - kv_chunks, - chunk_idx, - ) - out_chunks.append(chunk_out) - - # Concatenate chunked outputs - out_chunked = torch.cat(out_chunks, dim=1) - - # Compare - diff = (out_ref - out_chunked).abs() - max_diff = diff.max().item() - mean_diff = diff.mean().item() - - # Tolerance: fp16/bf16 have 
limited precision - tol = 1e-2 - passed = max_diff < tol - all_passed = all_passed and passed - - status = "PASS" if passed else "FAIL" - print( - f"[{status}] seqlen={seqlen:5d} chunks={num_chunks} " - f"chunk_size={chunk_size:4d} heads={nheads:2d} dim={headdim:3d} " - f"max_diff={max_diff:.6f} mean_diff={mean_diff:.8f}" - ) - - print() - -print("=" * 80) -assert all_passed, "Some tests failed!" -print("test_chunked_attention: PASSED") diff --git a/tests/test_chunked_decode_hook.py b/tests/test_chunked_decode_hook.py deleted file mode 100644 index e1dcfec..0000000 --- a/tests/test_chunked_decode_hook.py +++ /dev/null @@ -1,391 +0,0 @@ -""" -Correctness test for chunked decode attention. - -Captures Q and output during inference, then computes reference using -CPU KV cache with standard flash attention. -""" - -import os -os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING" - -import torch -from random import randint, seed -from typing import Dict, List -from nanovllm import LLM, SamplingParams -from nanovllm.utils.context import get_context -from flash_attn.flash_attn_interface import flash_attn_func - -# Config -MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") -MAX_MODEL_LEN = 128 * 1024 -NUM_GPU_BLOCKS = 2 -INPUT_LEN = 16 * 1024 -NUM_DECODE_TOKENS = 5 -BLOCK_SIZE = 1024 - -# State -prefill_captures: List[Dict] = [] -decode_captures: List[Dict] = [] - - -def make_ones_injection_hook(): - """Inject Q=K=V=1.0 for deterministic testing.""" - def hook(module, inputs): - q, k, v = inputs[0], inputs[1], inputs[2] - q_ones = torch.ones_like(q) - k_ones = torch.ones_like(k) - v_ones = torch.ones_like(v) - return (q_ones, k_ones, v_ones) + inputs[3:] - return hook - - -def make_capture_hook(layer_id: int): - """Capture Q, K, V, output during inference.""" - def hook(module, inputs, output): - ctx = get_context() - q, k, v = inputs - - if ctx.is_prefill: - chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 - prefill_captures.append({ - 'layer_id': layer_id, - 'chunk_idx': chunk_idx, - 'q': q.clone().cpu(), - 'k': k.clone().cpu(), - 'v': v.clone().cpu(), - 'output': output.clone().cpu(), - }) - else: - decode_step = len([c for c in decode_captures if c['layer_id'] == layer_id]) - decode_captures.append({ - 'layer_id': layer_id, - 'decode_step': decode_step, - 'q': q.clone().cpu(), - 'k': k.clone().cpu(), - 'v': v.clone().cpu(), - 'output': output.clone().cpu(), - }) - return hook - - -def compute_decode_reference(layer_id: int, decode_step: int, scale: float, - k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor, - block_size: int, num_prefill_chunks: int) -> torch.Tensor: - """ - Compute reference decode output using CPU KV cache and standard flash attention. - - For decode, query attends to: - 1. All prefill KV (from CPU cache) - 2. All previous decode tokens (from captured decode k, v) - """ - # Get decode capture for this layer and step - decode_cap = None - for c in decode_captures: - if c['layer_id'] == layer_id and c['decode_step'] == decode_step: - decode_cap = c - break - - if decode_cap is None: - return None - - # Query: single decode token - q = decode_cap['q'].cuda() # [1, num_heads, head_dim] - q_batched = q.unsqueeze(0) # [1, 1, num_heads, head_dim] - - # Collect all K, V: prefill chunks from captures + decode tokens from captures - # NOTE: We use prefill captures directly instead of CPU cache because - # the CPU block ID may not equal the chunk index. - all_k = [] - all_v = [] - - # 1. 
Prefill chunks from captures (use captured K/V, not CPU cache) - for cidx in range(num_prefill_chunks): - prefill_cap = None - for c in prefill_captures: - if c['layer_id'] == layer_id and c['chunk_idx'] == cidx: - prefill_cap = c - break - - if prefill_cap is not None: - # Use captured K/V directly (guaranteed to be correct layer data) - all_k.append(prefill_cap['k'].cuda()) - all_v.append(prefill_cap['v'].cuda()) - - # 2. Decode tokens from captures (up to and including current step) - for step in range(decode_step + 1): - for c in decode_captures: - if c['layer_id'] == layer_id and c['decode_step'] == step: - all_k.append(c['k'].cuda()) - all_v.append(c['v'].cuda()) - break - - if not all_k: - return None - - # Concatenate all K, V - full_k = torch.cat(all_k, dim=0).unsqueeze(0) # [1, total_len, kv_heads, head_dim] - full_v = torch.cat(all_v, dim=0).unsqueeze(0) - - # Run flash attention (non-causal since we explicitly control what KV to include) - output = flash_attn_func( - q_batched, full_k, full_v, - softmax_scale=scale, - causal=False, - ) - - return output.squeeze(0).squeeze(0).cpu() # [num_heads, head_dim] - - -# ============================================================ -# Main -# ============================================================ - -llm = LLM( - MODEL_PATH, - enforce_eager=True, - max_model_len=MAX_MODEL_LEN, - max_num_batched_tokens=MAX_MODEL_LEN, - enable_cpu_offload=True, - kvcache_block_size=BLOCK_SIZE, - num_gpu_blocks=NUM_GPU_BLOCKS, - dtype="float16", -) - -# Get model info -num_layers = len(llm.model_runner.model.model.layers) -head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim -scale = head_dim ** -0.5 - -# Register hooks -hooks = [] -for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers): - # Pre-hook: inject all ones for Q, K, V - # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook()) - # hooks.append(pre_hook) - # Post-hook: capture Q, K, V, output - post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx)) - hooks.append(post_hook) - -# Run inference -seed(42) -prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] -outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=NUM_DECODE_TOKENS), use_tqdm=False) - -# Remove hooks -for hook in hooks: - hook.remove() - -# Get CPU cache reference -offload_engine = llm.model_runner.kvcache_manager.offload_engine -k_cache_cpu = offload_engine.k_cache_cpu.clone() -v_cache_cpu = offload_engine.v_cache_cpu.clone() - -# Calculate number of prefill chunks -num_prefill_chunks = INPUT_LEN // BLOCK_SIZE - -# Debug: Compare decode_buffer with captured K/V -print("\n=== DEBUG: Comparing decode_buffer with captured K/V ===") -decode_k_buffer = offload_engine.decode_k_buffer.clone().cpu() -for step in range(NUM_DECODE_TOKENS): - for layer_id in [0, 17, 35]: # Sample a few layers - # Find captured K for this step and layer - for c in decode_captures: - if c['layer_id'] == layer_id and c['decode_step'] == step: - captured_k = c['k'].squeeze(0) # [kv_heads, head_dim] - buffer_k = decode_k_buffer[layer_id, step] # [kv_heads, head_dim] - diff = (captured_k - buffer_k).abs().max().item() - print(f"Step {step}, Layer {layer_id}: captured vs buffer max_diff={diff:.6f}") - break - -# Debug: Verify that decode_buffer slices match concatenated captures -print("\n=== DEBUG: Verifying decode_buffer slices ===") -for layer_id in [0]: - for decode_step in [1, 2]: # Check steps 
that use multiple tokens - # Build expected slice from captures - expected_k_list = [] - for step in range(decode_step + 1): - for c in decode_captures: - if c['layer_id'] == layer_id and c['decode_step'] == step: - expected_k_list.append(c['k'].squeeze(0)) # [kv_heads, head_dim] - break - if expected_k_list: - expected_k = torch.stack(expected_k_list, dim=0) # [num_tokens, kv_heads, head_dim] - buffer_slice = decode_k_buffer[layer_id, 0:decode_step+1] - diff = (expected_k - buffer_slice).abs().max().item() - print(f"Decode step {decode_step}, Layer {layer_id}: buffer slice vs expected max_diff={diff:.6f}") - # Print first values - print(f" Buffer[0,0,0]={buffer_slice[0,0,0].item():.6f}, Expected[0,0,0]={expected_k[0,0,0].item():.6f}") - if decode_step >= 1: - print(f" Buffer[1,0,0]={buffer_slice[1,0,0].item():.6f}, Expected[1,0,0]={expected_k[1,0,0].item():.6f}") - -# Debug: Print expected K value for block 0, layer 0 (to compare with actual loading) -print("\n=== DEBUG: Expected K values for block 0, layer 0 ===") -for c in prefill_captures: - if c['layer_id'] == 0 and c['chunk_idx'] == 0: - print(f"Captured K[0,0,0] for layer 0, chunk 0: {c['k'][0,0,0].item():.6f}") - break -print(f"CPU cache K[0,0,0,0,0] for layer 0, block 0: {k_cache_cpu[0,0,0,0,0].item():.6f}") - -# Debug: Compare CPU cache with captured prefill K/V -print("\n=== DEBUG: Comparing CPU cache with captured prefill K/V ===") -for chunk_idx in [0, 7, 15]: # Sample a few chunks - for layer_id in [0, 17, 35]: # Sample a few layers - # Find captured K for this chunk and layer - for c in prefill_captures: - if c['layer_id'] == layer_id and c['chunk_idx'] == chunk_idx: - captured_k = c['k'] # [seq_len, kv_heads, head_dim] - cpu_cache_k = k_cache_cpu[layer_id, chunk_idx, :captured_k.shape[0]] - diff = (captured_k - cpu_cache_k).abs().max().item() - print(f"Chunk {chunk_idx}, Layer {layer_id}: captured vs CPU cache max_diff={diff:.6f}") - break - -# Debug: Get cpu_block_table to check order -kvcache_manager = llm.model_runner.kvcache_manager -# Find the sequence (it should still exist) -from nanovllm.engine.sequence import Sequence -for attr_name in ['sequences', '_sequences', 'active_sequences']: - if hasattr(kvcache_manager, attr_name): - print(f"Found {attr_name}") - break - -# Try to get cpu_block_table through a different way -print(f"\n=== DEBUG: CPU block order ===") -# For each prefill capture, check which CPU block it ended up in -for chunk_idx in range(num_prefill_chunks): - for c in prefill_captures: - if c['layer_id'] == 0 and c['chunk_idx'] == chunk_idx: - # Check if this chunk's K matches any CPU block - captured_k_first = c['k'][0, 0, 0].item() - for block_id in range(num_prefill_chunks): - cpu_k_first = k_cache_cpu[0, block_id, 0, 0, 0].item() - if abs(captured_k_first - cpu_k_first) < 1e-6: - print(f"Chunk {chunk_idx} -> CPU block {block_id}") - break - break - -# Debug: Check reference vs actual for decode steps 0 and 1 -# Also compute partial references (prefill only, decode only) to isolate the bug -from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs -for decode_step in [0, 1]: - print(f"\n=== DEBUG: Reference vs Actual for layer 0, decode {decode_step} ===") - layer_id = 0 - # Find the capture - for c in decode_captures: - if c['layer_id'] == layer_id and c['decode_step'] == decode_step: - q = c['q'].cuda() # [1, num_heads, head_dim] - q_batched = q.unsqueeze(0) # [1, 1, num_heads, head_dim] - - # Build prefill K/V per-block for block-by-block reference - 
prefill_k_blocks = [] - prefill_v_blocks = [] - for cidx in range(num_prefill_chunks): - for pc in prefill_captures: - if pc['layer_id'] == layer_id and pc['chunk_idx'] == cidx: - prefill_k_blocks.append(pc['k'].cuda().unsqueeze(0)) # [1, block_size, kv_heads, head_dim] - prefill_v_blocks.append(pc['v'].cuda().unsqueeze(0)) - break - - # Build decode K/V - decode_k_list = [] - decode_v_list = [] - for step in range(decode_step + 1): - for dc in decode_captures: - if dc['layer_id'] == layer_id and dc['decode_step'] == step: - decode_k_list.append(dc['k'].cuda()) - decode_v_list.append(dc['v'].cuda()) - break - - full_prefill_k = torch.cat([kb.squeeze(0) for kb in prefill_k_blocks], dim=0).unsqueeze(0) - full_prefill_v = torch.cat([vb.squeeze(0) for vb in prefill_v_blocks], dim=0).unsqueeze(0) - full_decode_k = torch.cat(decode_k_list, dim=0).unsqueeze(0) - full_decode_v = torch.cat(decode_v_list, dim=0).unsqueeze(0) - - full_k = torch.cat([full_prefill_k, full_decode_k], dim=1) - full_v = torch.cat([full_prefill_v, full_decode_v], dim=1) - - print(f"Q shape: {q_batched.shape}") - print(f"Prefill K shape: {full_prefill_k.shape}") - print(f"Decode K shape: {full_decode_k.shape}") - print(f"Full K shape: {full_k.shape}") - print(f"Total tokens: prefill={num_prefill_chunks * BLOCK_SIZE}, decode={decode_step + 1}") - - # Reference output (single attention over all) - ref_output = flash_attn_func( - q_batched, full_k, full_v, - softmax_scale=scale, - causal=False, - ) - - # Chunked reference: prefill attention + decode attention + merge - prefill_o, prefill_lse = flash_attn_with_lse( - q_batched, full_prefill_k, full_prefill_v, - softmax_scale=scale, - causal=False, - ) - decode_o, decode_lse = flash_attn_with_lse( - q_batched, full_decode_k, full_decode_v, - softmax_scale=scale, - causal=False, - ) - chunked_output, _ = merge_attention_outputs(prefill_o, prefill_lse, decode_o, decode_lse) - - # Block-by-block reference (simulating ring buffer pipeline) - block_o_acc, block_lse_acc = None, None - for bidx, (kb, vb) in enumerate(zip(prefill_k_blocks, prefill_v_blocks)): - o_blk, lse_blk = flash_attn_with_lse(q_batched, kb, vb, softmax_scale=scale, causal=False) - if block_o_acc is None: - block_o_acc, block_lse_acc = o_blk, lse_blk - else: - block_o_acc, block_lse_acc = merge_attention_outputs(block_o_acc, block_lse_acc, o_blk, lse_blk) - - # Compare block-by-block vs single - block_vs_single_diff = (block_o_acc - prefill_o).abs().max().item() - print(f"Block-by-block vs Single max_diff: {block_vs_single_diff:.6f}") - - # Compare full reference vs chunked reference - ref_vs_chunked_diff = (ref_output - chunked_output).abs().max().item() - print(f"Reference vs Chunked-reference max_diff: {ref_vs_chunked_diff:.6f}") - - ref_output = ref_output.squeeze(0).squeeze(0).cpu() - chunked_output_cpu = chunked_output.squeeze(0).squeeze(0).cpu() - - # Actual output - actual_output = c['output'].squeeze(0) - if actual_output.dim() == 3: - actual_output = actual_output.squeeze(0) - - diff_ref = (actual_output - ref_output).abs() - diff_chunked = (actual_output - chunked_output_cpu).abs() - print(f"Actual vs Reference max_diff: {diff_ref.max().item():.6f}") - print(f"Actual vs Chunked-ref max_diff: {diff_chunked.max().item():.6f}") - break -print() - -# Verify decode outputs -all_passed = True - -for c in decode_captures: - layer_id = c['layer_id'] - decode_step = c['decode_step'] - - ref_output = compute_decode_reference( - layer_id, decode_step, scale, - k_cache_cpu, v_cache_cpu, BLOCK_SIZE, num_prefill_chunks 
- ) - if ref_output is None: - continue - - actual_output = c['output'].squeeze(0) - if actual_output.dim() == 3: - actual_output = actual_output.squeeze(0) - - diff = (actual_output - ref_output).abs() - max_diff = diff.max().item() - - passed = max_diff < 1e-1 - all_passed = all_passed and passed - - if not passed: - print(f"[FAIL] Layer {layer_id}, Decode {decode_step}: max_diff={max_diff:.6f}") - -print(f"test_chunked_decode_hook: {'PASSED' if all_passed else 'FAILED'}") diff --git a/tests/test_chunked_prefill_hook.py b/tests/test_chunked_prefill_hook.py deleted file mode 100644 index cd00429..0000000 --- a/tests/test_chunked_prefill_hook.py +++ /dev/null @@ -1,196 +0,0 @@ -""" -Correctness test for chunked prefill attention. - -Captures Q and output during inference, then computes reference using -CPU KV cache with standard flash attention. -""" - -import os -os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING" - -import torch -from random import randint, seed -from typing import Dict, List -from nanovllm import LLM, SamplingParams -from nanovllm.utils.context import get_context -from flash_attn.flash_attn_interface import flash_attn_varlen_func - -# Config -MODEL_PATH = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/") -MAX_MODEL_LEN = 128 * 1024 -NUM_GPU_BLOCKS = 2 -INPUT_LEN = 16 * 1024 -BLOCK_SIZE = 1024 - -# State - capture Q and output for each (layer, chunk) -captures: List[Dict] = [] - - -def make_ones_injection_hook(): - """Inject Q=K=V=1.0 for deterministic testing.""" - def hook(module, inputs): - ctx = get_context() - if not ctx.is_prefill: - return inputs - - q, k, v = inputs[0], inputs[1], inputs[2] - q_ones = torch.ones_like(q) - k_ones = torch.ones_like(k) - v_ones = torch.ones_like(v) - return (q_ones, k_ones, v_ones) + inputs[3:] - return hook - - -def make_capture_hook(layer_id: int): - """Capture Q and output during prefill.""" - def hook(module, inputs, output): - ctx = get_context() - if not ctx.is_prefill: - return - - q, k, v = inputs - chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 - - captures.append({ - 'layer_id': layer_id, - 'chunk_idx': chunk_idx, - 'q': q.clone().cpu(), - 'k': k.clone().cpu(), - 'v': v.clone().cpu(), - 'output': output.clone().cpu(), - }) - return hook - - -def compute_reference(layer_id: int, chunk_idx: int, scale: float, - k_cache_cpu: torch.Tensor, v_cache_cpu: torch.Tensor, - block_size: int) -> torch.Tensor: - """ - Compute reference output using CPU KV cache and standard flash attention. - - Concatenates all Q, K, V from chunks 0..chunk_idx and runs causal attention, - then extracts output for the current chunk. 
- """ - # Get all captures for this layer up to chunk_idx - layer_captures = [c for c in captures - if c['layer_id'] == layer_id and c['chunk_idx'] <= chunk_idx] - layer_captures = sorted(layer_captures, key=lambda x: x['chunk_idx']) - - if not layer_captures: - return None - - # Collect Q from captures, K/V from CPU cache - all_q = [] - all_k = [] - all_v = [] - chunk_lengths = [] - - for c in layer_captures: - cidx = c['chunk_idx'] - q = c['q'].cuda() # [seqlen, nheads, headdim] - all_q.append(q) - chunk_lengths.append(q.shape[0]) - - # Get K, V from CPU cache (already offloaded during prefill) - # CPU cache shape: [num_layers, num_blocks, block_size, kv_heads, head_dim] - k = k_cache_cpu[layer_id, cidx, :q.shape[0]].cuda() - v = v_cache_cpu[layer_id, cidx, :q.shape[0]].cuda() - all_k.append(k) - all_v.append(v) - - # Concatenate - full_q = torch.cat(all_q, dim=0) - full_k = torch.cat(all_k, dim=0) - full_v = torch.cat(all_v, dim=0) - total_len = full_q.shape[0] - - # Run standard causal flash attention - cu_seqlens = torch.tensor([0, total_len], dtype=torch.int32, device='cuda') - full_o = flash_attn_varlen_func( - full_q, full_k, full_v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=total_len, - max_seqlen_k=total_len, - softmax_scale=scale, - causal=True, - ) - - # Extract output for current chunk - start_pos = sum(chunk_lengths[:-1]) - end_pos = sum(chunk_lengths) - return full_o[start_pos:end_pos].cpu() - - -# ============================================================ -# Main -# ============================================================ - -llm = LLM( - MODEL_PATH, - enforce_eager=True, - max_model_len=MAX_MODEL_LEN, - max_num_batched_tokens=MAX_MODEL_LEN, - enable_cpu_offload=True, - kvcache_block_size=BLOCK_SIZE, - num_gpu_blocks=NUM_GPU_BLOCKS, - dtype="float16", -) - -# Get model info -num_layers = len(llm.model_runner.model.model.layers) -head_dim = llm.model_runner.model.model.layers[0].self_attn.attn.head_dim -scale = head_dim ** -0.5 - -# Register hooks -hooks = [] -for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers): - # Pre-hook: inject all ones for Q, K, V - # pre_hook = decoder_layer.self_attn.attn.register_forward_pre_hook(make_ones_injection_hook()) - # hooks.append(pre_hook) - # Post-hook: capture Q, K, V, output - post_hook = decoder_layer.self_attn.attn.register_forward_hook(make_capture_hook(layer_idx)) - hooks.append(post_hook) - -# Run inference -seed(42) -prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] -outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, max_tokens=1), use_tqdm=False) - -# Remove hooks -for hook in hooks: - hook.remove() - -# Get CPU cache reference -offload_engine = llm.model_runner.kvcache_manager.offload_engine -k_cache_cpu = offload_engine.k_cache_cpu.clone() -v_cache_cpu = offload_engine.v_cache_cpu.clone() - -# Verify: compare actual output with reference computed from CPU cache -all_passed = True -num_chunks = INPUT_LEN // BLOCK_SIZE - -for idx,c in enumerate(captures): - layer_id = c['layer_id'] - chunk_idx = c['chunk_idx'] - - # Skip chunk 0 (no previous KV to load) - if chunk_idx == 0: - continue - - ref_output = compute_reference(layer_id, chunk_idx, scale, k_cache_cpu, v_cache_cpu, BLOCK_SIZE) - if ref_output is None: - continue - - actual_output = c['output'] - diff = (actual_output - ref_output).abs() - max_diff = diff.max().item() - - passed = max_diff < 1e-1 # float16 tolerance - all_passed = all_passed and passed - - if not passed: - 
print(f"[FAIL] Layer {layer_id}, Chunk {chunk_idx}: max_diff={max_diff:.6f}") - __import__('pdb').set_trace() - -print(f"test_chunked_prefill_hook: {'PASSED' if all_passed else 'FAILED'}") diff --git a/tests/test_debug_verification.py b/tests/test_debug_verification.py deleted file mode 100644 index 532de2b..0000000 --- a/tests/test_debug_verification.py +++ /dev/null @@ -1,137 +0,0 @@ -""" -Test KV cache offload correctness using debug hooks. -Injects distinctive K/V values, verifies loaded tensors match expected patterns. -""" - -import os -os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING" - -import inspect -from random import randint, seed -from typing import Dict, List -import torch -from torch import Tensor -from nanovllm import LLM, SamplingParams -from nanovllm.utils.context import get_context - -# Config -MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/") -MAX_MODEL_LEN = 32 * 1024 -NUM_GPU_BLOCKS = 4 -INPUT_LEN = 32 * 1024 -BLOCK_SIZE = 1024 - -# State -load_log: List[Dict] = [] -current_chunk: List[int] = [0] - - -def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor, v: Tensor) -> None: - """Record loaded tensor values for layer 0.""" - if layer_id != 0: - return - - # Go up the stack to find kvcache_manager and print k_cache_gpu[*][0,0,0] for all slots - frame = inspect.currentframe() - try: - caller_frame = frame.f_back - if caller_frame is not None: - local_vars = caller_frame.f_locals - if 'self' in local_vars: - self_obj = local_vars['self'] - if hasattr(self_obj, 'k_cache_gpu'): - num_slots = self_obj.k_cache_gpu.shape[0] - vals = [] - for i in range(num_slots): - v = self_obj.k_cache_gpu[i][0,0,0].item() - if i == slot_idx: - vals.append(f"[{v}]") - else: - vals.append(str(v)) - print(f"[DEBUG] k_cache_gpu[0..{num_slots-1}][0,0,0] = [{', '.join(vals)}]") - finally: - del frame - - load_log.append({ - "chunk_idx": current_chunk[0], - "cpu_block_id": cpu_block_id, - "k_value": k.float().mean().item(), - }) - - -def make_pattern_injection_hook(layer_id): - """Inject K = chunk_idx + 1, V = -(chunk_idx + 1) for layer 0.""" - def hook(module, inputs): - ctx = get_context() - if not ctx.is_prefill or layer_id != 0: - return inputs - chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 - current_chunk[0] = chunk_idx - if len(inputs) >= 3: - q, k, v = inputs[0], inputs[1], inputs[2] - k_new = torch.full_like(k, float(chunk_idx + 1)) - v_new = torch.full_like(v, float(-(chunk_idx + 1))) - return (q, k_new, v_new) + inputs[3:] - return inputs - return hook - - -def verify() -> bool: - """Verify blocks loaded in correct order with correct K values.""" - chunk_loads: Dict[int, List[tuple]] = {} - for log in load_log: - chunk = log["chunk_idx"] - if chunk not in chunk_loads: - chunk_loads[chunk] = [] - chunk_loads[chunk].append((log["cpu_block_id"], log["k_value"])) - - for chunk, loads in chunk_loads.items(): - expected_blocks = list(range(chunk)) - actual_blocks = [b for b, _ in loads] - k_values = [k for _, k in loads] - expected_k = [float(b + 1) for b in expected_blocks] - - if actual_blocks != expected_blocks: - return False - if not all(abs(a - e) < 1e-2 for a, e in zip(k_values, expected_k)): - return False - return True - - -# Main -llm = LLM( - MODEL_PATH, - enforce_eager=True, - max_model_len=MAX_MODEL_LEN, - max_num_batched_tokens=MAX_MODEL_LEN, - enable_cpu_offload=True, - kvcache_block_size=BLOCK_SIZE, - num_gpu_blocks=NUM_GPU_BLOCKS, - dtype="float16", -) - -offload_engine = llm.model_runner.kvcache_manager.offload_engine 
-offload_engine.enable_debug_mode() -offload_engine.register_debug_hook(debug_load_hook) - -hooks = [] -for layer_idx, decoder_layer in enumerate(llm.model_runner.model.model.layers): - hooks.append(decoder_layer.self_attn.attn.register_forward_pre_hook( - make_pattern_injection_hook(layer_idx) - )) - -seed(42) -prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] -outputs = llm.generate(prompt_token_ids, SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1), use_tqdm=False) - -for hook in hooks: - hook.remove() -offload_engine.remove_debug_hook(debug_load_hook) -offload_engine.disable_debug_mode() - -# Verify -num_chunks = INPUT_LEN // BLOCK_SIZE -expected_loads = num_chunks * (num_chunks - 1) // 2 -passed = len(load_log) == expected_loads and verify() - -print(f"test_debug_verification: {'PASSED' if passed else 'FAILED'}") diff --git a/tests/test_flash_attn_kvcache.py b/tests/test_flash_attn_kvcache.py deleted file mode 100644 index d488d6a..0000000 --- a/tests/test_flash_attn_kvcache.py +++ /dev/null @@ -1,276 +0,0 @@ -""" -Test script for flash_attn_with_kvcache based chunked prefill. - -Verifies that chunked prefill produces identical results to full attention. -""" - -import torch -from flash_attn import flash_attn_func, flash_attn_with_kvcache - - -def chunk_prefill(q_full, k_full, v_full, k_cache, v_cache, cache_seqlens, chunk_size): - """ - Chunked prefill using flash_attn_with_kvcache. - - Args: - q_full, k_full, v_full: [batch, total_seq_len, heads, head_dim] - k_cache, v_cache: [batch, max_seq_len, kv_heads, head_dim] - cache_seqlens: [batch] - current cache lengths - chunk_size: size of each chunk - - Returns: - output: [batch, total_seq_len, heads, head_dim] - """ - total_len = q_full.shape[1] - outputs = [] - - for start in range(0, total_len, chunk_size): - end = min(start + chunk_size, total_len) - - q_chunk = q_full[:, start:end] - k_chunk = k_full[:, start:end] - v_chunk = v_full[:, start:end] - - out = flash_attn_with_kvcache( - q_chunk, - k_cache, - v_cache, - k=k_chunk, - v=v_chunk, - cache_seqlens=cache_seqlens, - causal=True, - ) - outputs.append(out) - - cache_seqlens += (end - start) - - return torch.cat(outputs, dim=1) - - -def reference_attention(q, k, v): - """Standard flash attention as reference.""" - return flash_attn_func(q, k, v, causal=True) - - -def test_chunked_prefill_correctness(): - """Test that chunked prefill matches full attention.""" - - batch_size = 1 - num_heads = 32 - num_kv_heads = 8 # GQA - head_dim = 128 - max_seq_len = 131072 # 128K - - test_configs = [ - (1024, 256), # 1K tokens, 256 chunk - (2048, 512), # 2K tokens, 512 chunk - (4096, 1024), # 4K tokens, 1K chunk - (4096, 2048), # 4K tokens, 2K chunk (2 chunks) - (8192, 2048), # 8K tokens, 2K chunk (4 chunks) - (16384, 4096), # 16K tokens, 4K chunk - (32768, 4096), # 32K tokens, 4K chunk - (65536, 8192), # 64K tokens, 8K chunk - (131072, 8192), # 128K tokens, 8K chunk (16 chunks) - ] - - for seq_len, chunk_size in test_configs: - print(f"\nTesting seq_len={seq_len}, chunk_size={chunk_size}...") - - # Generate random input - torch.manual_seed(42) - q = torch.randn(batch_size, seq_len, num_heads, head_dim, - dtype=torch.float16, device='cuda') - k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - - # Expand K/V for non-GQA reference - k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2) - v_expanded = 
v.repeat_interleave(num_heads // num_kv_heads, dim=2) - - # Reference: full attention - ref_out = reference_attention(q, k_expanded, v_expanded) - - # Chunked prefill with KV cache - k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda') - - chunked_out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size) - - # Compare - max_diff = (ref_out - chunked_out).abs().max().item() - mean_diff = (ref_out - chunked_out).abs().mean().item() - - # Verify cache was filled correctly - assert cache_seqlens[0].item() == seq_len, f"Cache seqlen mismatch: {cache_seqlens[0].item()} != {seq_len}" - - # Check K/V cache content - k_cache_diff = (k_cache[:, :seq_len] - k).abs().max().item() - v_cache_diff = (v_cache[:, :seq_len] - v).abs().max().item() - - print(f" Output max_diff: {max_diff:.6f}, mean_diff: {mean_diff:.6f}") - print(f" KV cache diff: k={k_cache_diff:.6f}, v={v_cache_diff:.6f}") - - # Tolerance for fp16 - tolerance = 1e-2 - if max_diff < tolerance: - print(f" PASSED") - else: - print(f" FAILED (max_diff {max_diff:.6f} >= {tolerance})") - return False - - return True - - -def test_incremental_decode(): - """Test that decode after chunked prefill works correctly.""" - - batch_size = 1 - num_heads = 32 - num_kv_heads = 8 - head_dim = 128 - max_seq_len = 8192 - - prefill_len = 2048 - chunk_size = 512 - num_decode_steps = 10 - - print(f"\nTesting incremental decode after chunked prefill...") - print(f" Prefill: {prefill_len} tokens, chunk_size={chunk_size}") - print(f" Decode: {num_decode_steps} steps") - - torch.manual_seed(42) - - # Prefill phase - q_prefill = torch.randn(batch_size, prefill_len, num_heads, head_dim, - dtype=torch.float16, device='cuda') - k_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - v_prefill = torch.randn(batch_size, prefill_len, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - - k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda') - - # Run chunked prefill - prefill_out = chunk_prefill(q_prefill, k_prefill, v_prefill, - k_cache, v_cache, cache_seqlens, chunk_size) - - print(f" After prefill: cache_seqlens={cache_seqlens[0].item()}") - - # Decode phase - one token at a time - for step in range(num_decode_steps): - q_decode = torch.randn(batch_size, 1, num_heads, head_dim, - dtype=torch.float16, device='cuda') - k_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - v_decode = torch.randn(batch_size, 1, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - - decode_out = flash_attn_with_kvcache( - q_decode, - k_cache, - v_cache, - k=k_decode, - v=v_decode, - cache_seqlens=cache_seqlens, - causal=True, - ) - - cache_seqlens += 1 - - assert decode_out.shape == (batch_size, 1, num_heads, head_dim) - - expected_len = prefill_len + num_decode_steps - actual_len = cache_seqlens[0].item() - - print(f" After decode: cache_seqlens={actual_len}") - - if actual_len == expected_len: - print(f" PASSED") - return True - else: - print(f" FAILED: expected 
{expected_len}, got {actual_len}") - return False - - -def test_batch_processing(): - """Test chunked prefill with batch > 1.""" - - batch_size = 4 - num_heads = 32 - num_kv_heads = 8 - head_dim = 128 - max_seq_len = 4096 - seq_len = 2048 - chunk_size = 512 - - print(f"\nTesting batch processing (batch_size={batch_size})...") - - torch.manual_seed(42) - - q = torch.randn(batch_size, seq_len, num_heads, head_dim, - dtype=torch.float16, device='cuda') - k = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - v = torch.randn(batch_size, seq_len, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - - k_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - v_cache = torch.zeros(batch_size, max_seq_len, num_kv_heads, head_dim, - dtype=torch.float16, device='cuda') - cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device='cuda') - - out = chunk_prefill(q, k, v, k_cache, v_cache, cache_seqlens, chunk_size) - - # Verify all batches have correct cache length - assert (cache_seqlens == seq_len).all(), f"Cache seqlens mismatch: {cache_seqlens}" - assert out.shape == (batch_size, seq_len, num_heads, head_dim) - - # Compare with reference for each batch item - k_expanded = k.repeat_interleave(num_heads // num_kv_heads, dim=2) - v_expanded = v.repeat_interleave(num_heads // num_kv_heads, dim=2) - ref_out = reference_attention(q, k_expanded, v_expanded) - - max_diff = (ref_out - out).abs().max().item() - - print(f" Output shape: {out.shape}") - print(f" Max diff vs reference: {max_diff:.6f}") - - if max_diff < 1e-2: - print(f" PASSED") - return True - else: - print(f" FAILED") - return False - - -# ============================================================ -# Main Test Script -# ============================================================ - -if __name__ == "__main__": - print("=" * 60) - print("Testing flash_attn_with_kvcache chunked prefill") - print("=" * 60) - - all_passed = True - - all_passed &= test_chunked_prefill_correctness() - all_passed &= test_incremental_decode() - all_passed &= test_batch_processing() - - print("\n" + "=" * 60) - if all_passed: - print("test_flash_attn_kvcache: ALL TESTS PASSED") - else: - print("test_flash_attn_kvcache: SOME TESTS FAILED") - print("=" * 60) diff --git a/tests/test_flashinfer_merge.py b/tests/test_flashinfer_merge.py deleted file mode 100644 index 7aa57a6..0000000 --- a/tests/test_flashinfer_merge.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -Test FlashInfer chunked attention with CPU offload. - -Uses single_prefill_with_kv_cache + merge_state for chunked KV processing. -""" - -import torch -import flashinfer - - -# ============================================================ -# Core Functions -# ============================================================ - -def chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk_size, kv_chunk_size): - """ - Chunked causal attention with KV on CPU. 
- - q: [seq_q, num_heads, head_dim] on GPU - k_cpu, v_cpu: [seq_kv, num_kv_heads, head_dim] on CPU - """ - seq_q = q.shape[0] - seq_kv = k_cpu.shape[0] - final_outputs = [] - - for q_start in range(0, seq_q, q_chunk_size): - q_end = min(q_start + q_chunk_size, seq_q) - q_chunk = q[q_start:q_end] - - merged_output = None - merged_lse = None - - for kv_start in range(0, seq_kv, kv_chunk_size): - kv_end = min(kv_start + kv_chunk_size, seq_kv) - - if kv_start >= q_end: - continue - - k_chunk = k_cpu[kv_start:kv_end].to(q.device, non_blocking=True) - v_chunk = v_cpu[kv_start:kv_end].to(q.device, non_blocking=True) - - causal = not (kv_end <= q_start) - partial_out, partial_lse = flashinfer.single_prefill_with_kv_cache( - q_chunk, k_chunk, v_chunk, - causal=causal, - return_lse=True, - ) - - if merged_output is None: - merged_output, merged_lse = partial_out, partial_lse - else: - merged_output, merged_lse = flashinfer.merge_state( - merged_output, merged_lse, - partial_out, partial_lse, - ) - - final_outputs.append(merged_output) - - return torch.cat(final_outputs, dim=0) - - -# ============================================================ -# Main Test Script -# ============================================================ - -print("=" * 60) -print("Testing FlashInfer chunked attention with CPU offload") -print("=" * 60) - -num_heads = 32 -num_kv_heads = 8 -head_dim = 128 - -test_configs = [ - (32768, 8192, 8192), # 32K tokens - (65536, 8192, 8192), # 64K tokens - (131072, 16384, 16384), # 128K tokens - # (262144, 16384, 16384), # 256K tokens (slow) - # (524288, 16384, 16384), # 512K tokens (slow) -] - -for seq_len, q_chunk, kv_chunk in test_configs: - torch.manual_seed(42) - - q = torch.randn(seq_len, num_heads, head_dim, dtype=torch.float16, device='cuda') - k_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu') - v_cpu = torch.randn(seq_len, num_kv_heads, head_dim, dtype=torch.float16, device='cpu') - - # Chunked result - chunked_out = chunked_prefill_causal(q, k_cpu, v_cpu, q_chunk, kv_chunk) - - # Reference - k_gpu = k_cpu.to('cuda') - v_gpu = v_cpu.to('cuda') - ref_out = flashinfer.single_prefill_with_kv_cache(q, k_gpu, v_gpu, causal=True) - - max_diff = (ref_out - chunked_out).abs().max().item() - mean_diff = (ref_out - chunked_out).abs().mean().item() - - num_chunks = (seq_len + q_chunk - 1) // q_chunk - assert max_diff < 1e-2, f"FAILED: max_diff={max_diff:.6f}" - print(f"seq={seq_len//1024}K, chunks={num_chunks}: max_diff={max_diff:.6f}, mean_diff={mean_diff:.6f}") - -print("\ntest_flashinfer_merge: PASSED") diff --git a/tests/test_nanovllm_steppable.py b/tests/test_nanovllm_steppable.py deleted file mode 100644 index 2381ef3..0000000 --- a/tests/test_nanovllm_steppable.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Test NanovllmSteppable: Print activation statistics at each layer. 
- -Usage: - python tests/test_nanovllm_steppable.py -""" - -import os -os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING" - -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).parent)) - -import torch -from transformers import AutoTokenizer -from nanovllm import LLM -from nanovllm.debug.adapters.nanovllm_adapter import NanovllmSteppable -from utils import generate_needle_prompt, check_needle_answer - -# ============================================================ -# Config -# ============================================================ -MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/") -INPUT_LEN = 32768 # Longer context to test offload -MAX_NEW_TOKENS = 20 -DTYPE = torch.float16 -ENABLE_CPU_OFFLOAD = True # Test offload mode - -# ============================================================ -# Load Model -# ============================================================ -print(f"Loading nanovllm model (cpu_offload={ENABLE_CPU_OFFLOAD})...") -llm = LLM( - MODEL_PATH, - enforce_eager=True, # Required for hooks to work - max_model_len=40960, - max_num_batched_tokens=40960, - enable_cpu_offload=ENABLE_CPU_OFFLOAD, - dtype="float16", -) -tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) - -# Get the underlying model for steppable -model = llm.model_runner.model - -# ============================================================ -# Prepare Input (using needle-in-haystack prompt) -# ============================================================ -prompt, expected_answer = generate_needle_prompt( - tokenizer, - target_length=INPUT_LEN, - needle_position=0.5, - needle_value="7492", - use_chat_template=False, - verbose=True, -) -input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda") -print(f"Input shape: {input_ids.shape}") -print(f"Expected answer: {expected_answer}\n") - -# ============================================================ -# Create Steppable Model (reused for prefill + decode) -# ============================================================ -steppable = NanovllmSteppable(model) - -# ============================================================ -# Prefill Phase: Print activation stats -# ============================================================ -print("=" * 85) -print("PREFILL PHASE") -print("=" * 85) -print(f"{'Layer':<15} {'Shape':<25} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}") -print("-" * 85) - -current_ids = input_ids.clone() -logits = None - -for bp in steppable.step(current_ids, is_prefill=True): - t = bp.tensor.float() - shape_str = str(list(t.shape)) - print(f"{bp.name:<15} {shape_str:<25} {t.mean():>10.4f} {t.std():>10.4f} {t.min():>10.4f} {t.max():>10.4f}") - if bp.name == "LM Head": - logits = bp.tensor - -# Get first token from prefill -next_token_id = logits[0, -1].argmax().item() -next_token = tokenizer.decode(next_token_id) -current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1) -generated_tokens = [next_token] - -# ============================================================ -# Decode Phase: Only print generated tokens -# ============================================================ -print("\n" + "=" * 85) -print("DECODE PHASE") -print("=" * 85) -print(f"Step 1: {next_token!r}") - -for step in range(2, MAX_NEW_TOKENS + 1): - # Forward pass with full sequence (reuse same steppable) - # Note: nanovllm without KV cache needs full sequence for each decode - for bp in steppable.step(current_ids, is_prefill=True): - if bp.name == "LM Head": - logits = bp.tensor - - # Get next token 
(greedy) - next_token_id = logits[0, -1].argmax().item() - next_token = tokenizer.decode(next_token_id) - generated_tokens.append(next_token) - - print(f"Step {step:2}: {next_token!r}") - - # Stop if EOS - if next_token_id == tokenizer.eos_token_id: - print(" (EOS)") - break - - # Append to sequence - current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1) - -# ============================================================ -# Result -# ============================================================ -print("\n" + "=" * 85) -print("RESULT") -print("=" * 85) -generated_text = "".join(generated_tokens) -print(f"Generated: {generated_text!r}") -print(f"Expected: {expected_answer}") -print(f"Answer: {'CORRECT!' if check_needle_answer(generated_text, expected_answer) else 'INCORRECT'}") - -print("\ntest_nanovllm_steppable: PASSED") diff --git a/tests/test_offload_correctness.py b/tests/test_offload_correctness.py deleted file mode 100644 index f8fb61d..0000000 --- a/tests/test_offload_correctness.py +++ /dev/null @@ -1,695 +0,0 @@ -""" -Test script to verify CPU offload correctness using distinctive KV patterns. - -Strategy: -1. Hook into attention forward pass -2. Overwrite K/V with distinctive patterns based on chunk_idx (e.g., K=chunk_idx, V=-chunk_idx) -3. After offload to CPU, verify CPU cache contains correct patterns -4. On subsequent chunks, verify loaded KV from CPU has correct patterns - -This catches bugs like: -- Wrong block being offloaded -- Wrong block being loaded -- Data corruption during transfer -""" - -import os -os.environ["NANOVLLM_LOG_LEVEL"] = "INFO" - -import torch -from random import randint, seed -from nanovllm import LLM, SamplingParams -from nanovllm.utils.context import get_context - - -# ============================================================ -# Configuration -# ============================================================ - -MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/") -MAX_MODEL_LEN = 64 * 1024 -NUM_GPU_BLOCKS = 4 -INPUT_LEN = 32 * 1024 # 32K tokens = 32 chunks (fits in 40 CPU blocks) -BLOCK_SIZE = 1024 - -# Test state -errors = [] -chunk_patterns = {} # chunk_idx -> (k_pattern, v_pattern) -block_coverage = {} # chunk_idx -> set of blocks that were actually computed -load_operations = [] # List of (chunk_idx, slot_id, cpu_block_id, k_ok, v_ok) tuples -current_chunk_for_load = [0] # Mutable container to track current chunk during loads - - -# ============================================================ -# Pattern Helpers -# ============================================================ - -def get_expected_pattern(chunk_idx: int): - """Get expected K/V pattern for a chunk.""" - # Use float values that are easy to identify - k_val = float(chunk_idx + 1) # 1.0, 2.0, 3.0, ... - v_val = float(-(chunk_idx + 1)) # -1.0, -2.0, -3.0, ... 
- return k_val, v_val - - -def fill_with_pattern(tensor: torch.Tensor, value: float): - """Fill tensor with a constant value.""" - tensor.fill_(value) - - -def check_pattern(tensor: torch.Tensor, expected: float, name: str, tolerance: float = 1e-3): - """Check if tensor contains expected pattern.""" - actual_mean = tensor.float().mean().item() - if abs(actual_mean - expected) > tolerance: - return False, f"{name}: expected mean={expected}, got {actual_mean}" - return True, None - - -# ============================================================ -# Load Verification Instrumentation -# ============================================================ - -_original_load_to_slot_layer = None -_offload_engine_ref = None - -def make_verified_load_to_slot_layer(original_func, offload_engine): - """ - Create a wrapper around load_to_slot_layer that verifies each load operation. - - After each H2D transfer, checks that the GPU slot contains the expected - pattern from the source CPU block. - """ - def verified_load(slot_idx: int, layer_id: int, cpu_block_id: int): - # Call original load - original_func(slot_idx, layer_id, cpu_block_id) - - # Only verify layer 0 to reduce overhead - if layer_id != 0: - return - - # IMPORTANT: Synchronize CUDA to ensure async transfer is complete - # The transfer happens on a per-slot stream, and wait_slot_layer only - # makes compute_stream wait. We need full sync to read on default stream. - torch.cuda.synchronize() - - # Get the expected pattern for this CPU block - # cpu_block_id == chunk_idx in our sequential test - expected_k, expected_v = get_expected_pattern(cpu_block_id) - - # Read GPU slot data (GPU cache has no layer dimension) - gpu_k = offload_engine.k_cache_gpu[slot_idx] - gpu_v = offload_engine.v_cache_gpu[slot_idx] - - actual_k = gpu_k.float().mean().item() - actual_v = gpu_v.float().mean().item() - - k_ok = abs(actual_k - expected_k) < 1e-3 - v_ok = abs(actual_v - expected_v) < 1e-3 - - chunk_idx = current_chunk_for_load[0] - load_operations.append({ - 'chunk_idx': chunk_idx, - 'slot_idx': slot_idx, - 'cpu_block_id': cpu_block_id, - 'expected_k': expected_k, - 'expected_v': expected_v, - 'actual_k': actual_k, - 'actual_v': actual_v, - 'k_ok': k_ok, - 'v_ok': v_ok, - }) - - if not (k_ok and v_ok): - errors.append(f"Load verification failed: chunk {chunk_idx}, " - f"CPU block {cpu_block_id} -> GPU slot {slot_idx}: " - f"expected K={expected_k:.1f}/V={expected_v:.1f}, " - f"got K={actual_k:.4f}/V={actual_v:.4f}") - - return verified_load - - -def install_load_verification(llm): - """Install verification wrapper on load_to_slot_layer.""" - global _original_load_to_slot_layer, _offload_engine_ref - - oe = llm.model_runner.kvcache_manager.offload_engine - _offload_engine_ref = oe - _original_load_to_slot_layer = oe.load_to_slot_layer - - oe.load_to_slot_layer = make_verified_load_to_slot_layer( - _original_load_to_slot_layer, oe - ) - print("Installed load verification wrapper on load_to_slot_layer") - - -def uninstall_load_verification(): - """Restore original load_to_slot_layer.""" - global _original_load_to_slot_layer, _offload_engine_ref - - if _offload_engine_ref is not None and _original_load_to_slot_layer is not None: - _offload_engine_ref.load_to_slot_layer = _original_load_to_slot_layer - print("Restored original load_to_slot_layer") - - _original_load_to_slot_layer = None - _offload_engine_ref = None - - -# ============================================================ -# Attention Hook -# ============================================================ - 
-def make_kv_pattern_pre_hook(layer_id: int): - """ - Create a PRE-forward hook that overwrites K/V with distinctive patterns BEFORE - they are stored to cache. This is called before attention.forward(). - - register_forward_pre_hook receives (module, inputs) and can modify inputs in-place. - """ - def hook(module, inputs): - ctx = get_context() - if not ctx.is_prefill: - return - - chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 - kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None - - if kvcache_manager is None: - return - - # Only process layer 0 for cleaner output - if layer_id != 0: - return - - q, k, v = inputs - k_pattern, v_pattern = get_expected_pattern(chunk_idx) - - # === Overwrite current chunk's K/V with distinctive pattern === - # This happens BEFORE forward(), so these values will be stored to cache - k.fill_(k_pattern) - v.fill_(v_pattern) - - # Only print for first few and last few chunks to reduce noise - num_chunks = INPUT_LEN // BLOCK_SIZE - if chunk_idx < 3 or chunk_idx >= num_chunks - 2: - print(f"[Chunk {chunk_idx:3d}] Set K={k_pattern:.1f}, V={v_pattern:.1f}") - elif chunk_idx == 3: - print(f"... (chunks 3 to {num_chunks - 3} omitted) ...") - - return hook - - -def make_block_coverage_pre_hook(layer_id: int): - """ - Create a PRE-forward hook to verify that all previous blocks are included - in the cpu_block_table for chunked prefill attention. - - This catches bugs where: - - Some blocks are missing from the computation - - Sparse policy incorrectly filters out blocks (when not intended) - - Block table construction has off-by-one errors - """ - def hook(module, inputs): - ctx = get_context() - if not ctx.is_prefill: - return - - chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 - kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None - - if kvcache_manager is None: - return - - # Only process layer 0 for cleaner output - if layer_id != 0: - return - - # Update current chunk for load verification tracking - current_chunk_for_load[0] = chunk_idx - - # No previous blocks for chunk 0 - if chunk_idx == 0: - return - - # Get the sequence and its block table (same logic as _chunked_prefill_attention) - seq = ctx.chunked_seq if hasattr(ctx, 'chunked_seq') else None - if seq is None: - return - - # Get the CPU block table that will be used for attention - cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq) - - # Expected blocks: 0 to chunk_idx-1 (all previous chunks) - expected_blocks = set(range(chunk_idx)) - actual_blocks = set(cpu_block_table) if cpu_block_table else set() - - # Store for later summary - block_coverage[chunk_idx] = { - 'expected': expected_blocks, - 'actual': actual_blocks, - } - - # Check for missing blocks - missing_blocks = expected_blocks - actual_blocks - extra_blocks = actual_blocks - expected_blocks - - num_chunks = INPUT_LEN // BLOCK_SIZE - if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or missing_blocks: - if not missing_blocks and not extra_blocks: - print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [OK]") - else: - status_parts = [] - if missing_blocks: - status_parts.append(f"MISSING {sorted(missing_blocks)}") - if extra_blocks: - status_parts.append(f"EXTRA {sorted(extra_blocks)}") - print(f" Block coverage chunk {chunk_idx:2d}: {len(actual_blocks)}/{len(expected_blocks)} blocks [{', '.join(status_parts)}]") - elif chunk_idx == 4: - # Indicate that middle chunks are 
being verified silently - print(f" ... (verifying chunks 4-{num_chunks - 3} silently) ...") - - if missing_blocks: - errors.append(f"Chunk {chunk_idx} missing blocks: {sorted(missing_blocks)}") - - return hook - - -def make_gpu_write_verification_post_hook(layer_id: int): - """ - Create a POST-forward hook to verify the current chunk's KV was correctly - written to the GPU ring buffer write_slot. - - This is a more reliable verification than checking load slots, because: - 1. Post-hook runs AFTER forward() writes to GPU cache - 2. write_slot mapping is deterministic: chunk_idx % num_ring_slots - 3. We injected known patterns in pre-hook, now verify they're in GPU cache - """ - def hook(module, inputs, output): - ctx = get_context() - if not ctx.is_prefill: - return - - chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 - kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None - - if kvcache_manager is None: - return - - # Only process layer 0 for cleaner output - if layer_id != 0: - return - - oe = kvcache_manager.offload_engine - num_ring_slots = oe.num_ring_slots - write_slot = chunk_idx % num_ring_slots - - # Get expected pattern for current chunk - expected_k, expected_v = get_expected_pattern(chunk_idx) - - # Verify write_slot contains current chunk's data (GPU cache has no layer dimension) - gpu_k = oe.k_cache_gpu[write_slot] - gpu_v = oe.v_cache_gpu[write_slot] - - actual_k_mean = gpu_k.float().mean().item() - actual_v_mean = gpu_v.float().mean().item() - - k_ok, _ = check_pattern(gpu_k, expected_k, f"GPU slot {write_slot}") - v_ok, _ = check_pattern(gpu_v, expected_v, f"GPU slot {write_slot}") - - num_chunks = INPUT_LEN // BLOCK_SIZE - # Print for first/last chunks, or if there's an error - if True or chunk_idx >= num_chunks - 2 or not (k_ok and v_ok): - if k_ok and v_ok: - print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: K={expected_k:.1f}, V={expected_v:.1f} [OK]") - else: - print(f" GPU write_slot[{write_slot}] chunk {chunk_idx:2d}: expected K={expected_k:.1f}/V={expected_v:.1f}, " - f"got K={actual_k_mean:.2f}/V={actual_v_mean:.2f} [FAIL]") - elif chunk_idx == 4: - print(f" ... (GPU write verification for chunks 4-{num_chunks - 3} silently) ...") - - if not (k_ok and v_ok): - errors.append(f"GPU write_slot {write_slot} at chunk {chunk_idx}: " - f"expected K={expected_k}, V={expected_v}, got K={actual_k_mean:.4f}, V={actual_v_mean:.4f}") - - return hook - - -def make_kv_verification_post_hook(layer_id: int): - """ - Create a POST-forward hook to verify CPU cache contains correct patterns - from previously offloaded blocks. 
- """ - def hook(module, inputs, output): - ctx = get_context() - if not ctx.is_prefill: - return - - chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 - kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None - - if kvcache_manager is None: - return - - # Only process layer 0 for cleaner output - if layer_id != 0: - return - - # === Verify previously offloaded blocks in CPU cache === - if chunk_idx >= 1: - oe = kvcache_manager.offload_engine - num_ok = 0 - num_fail = 0 - - # Check all previously offloaded blocks - for prev_chunk in range(chunk_idx): - # CPU block ID = prev_chunk (in simple sequential case) - cpu_block_id = prev_chunk - - # Get expected pattern for this block - expected_k, expected_v = get_expected_pattern(prev_chunk) - - # Read from CPU cache (layer 0) - cpu_k = oe.k_cache_cpu[layer_id, cpu_block_id] - cpu_v = oe.v_cache_cpu[layer_id, cpu_block_id] - - # Verify patterns - k_ok, k_err = check_pattern(cpu_k, expected_k, f"CPU K block {cpu_block_id}") - v_ok, v_err = check_pattern(cpu_v, expected_v, f"CPU V block {cpu_block_id}") - - if k_ok and v_ok: - num_ok += 1 - else: - num_fail += 1 - if k_err: - errors.append(k_err) - if v_err: - errors.append(v_err) - - # Only print summary for each chunk verification - num_chunks = INPUT_LEN // BLOCK_SIZE - if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or num_fail > 0: - status = "OK" if num_fail == 0 else f"FAIL({num_fail})" - print(f" CPU verify chunk {chunk_idx:2d}: {num_ok} blocks OK [{status}]") - elif chunk_idx == 4: - print(f" ... (CPU cache verification for chunks 4-{num_chunks - 3} silently) ...") - - return hook - - -def make_post_chunk_verification_hook(layer_id: int): - """ - Post-forward hook to verify GPU ring buffer state after attention. 
- """ - def hook(module, inputs, output): - ctx = get_context() - if not ctx.is_prefill or layer_id != 0: - return - - chunk_idx = ctx.current_chunk_idx if hasattr(ctx, 'current_chunk_idx') else 0 - kvcache_manager = ctx.kvcache_manager if hasattr(ctx, 'kvcache_manager') else None - - if kvcache_manager is None: - return - - oe = kvcache_manager.offload_engine - - # After attention, the current chunk's KV should be in the GPU ring buffer - # Ring slot = chunk_idx % num_ring_slots - ring_slot = chunk_idx % oe.num_ring_slots - - expected_k, expected_v = get_expected_pattern(chunk_idx) - - # Check GPU ring buffer (GPU cache has no layer dimension) - gpu_k = oe.k_cache_gpu[ring_slot] - gpu_v = oe.v_cache_gpu[ring_slot] - - k_ok, k_err = check_pattern(gpu_k, expected_k, f"GPU K slot {ring_slot}") - v_ok, v_err = check_pattern(gpu_v, expected_v, f"GPU V slot {ring_slot}") - - if k_ok and v_ok: - print(f" [OK] GPU slot {ring_slot} (chunk {chunk_idx}): K={expected_k}, V={expected_v}") - else: - if k_err: - print(f" [FAIL] {k_err}") - errors.append(k_err) - if v_err: - print(f" [FAIL] {v_err}") - errors.append(v_err) - - return hook - - -def register_hooks(llm): - """Register pre and post forward hooks.""" - hooks = [] - model = llm.model_runner.model - - for layer_idx, decoder_layer in enumerate(model.model.layers): - attn_module = decoder_layer.self_attn.attn - - # PRE-forward hook 1: Verify all previous blocks are in cpu_block_table - coverage_hook = attn_module.register_forward_pre_hook(make_block_coverage_pre_hook(layer_idx)) - hooks.append(coverage_hook) - - # PRE-forward hook 2: Inject K/V patterns before they're stored to cache - pattern_hook = attn_module.register_forward_pre_hook(make_kv_pattern_pre_hook(layer_idx)) - hooks.append(pattern_hook) - - # POST-forward hook 1: Verify GPU write_slot contains current chunk's data - gpu_verify_hook = attn_module.register_forward_hook(make_gpu_write_verification_post_hook(layer_idx)) - hooks.append(gpu_verify_hook) - - # POST-forward hook 2: Verify CPU cache contains correct patterns after offload - cpu_verify_hook = attn_module.register_forward_hook(make_kv_verification_post_hook(layer_idx)) - hooks.append(cpu_verify_hook) - - return hooks - - -# ============================================================ -# Final Verification -# ============================================================ - -def verify_final_cpu_state(llm, num_chunks: int): - """Verify all CPU blocks have correct patterns after prefill completes.""" - print("\n" + "=" * 60) - print("Final CPU Cache Verification") - print("=" * 60) - - kvcache_manager = llm.model_runner.kvcache_manager - oe = kvcache_manager.offload_engine - - num_ok = 0 - num_fail = 0 - fail_details = [] - - # After prefill, all chunks should be in CPU - for chunk_idx in range(num_chunks): - cpu_block_id = chunk_idx - expected_k, expected_v = get_expected_pattern(chunk_idx) - - # Check layer 0 - cpu_k = oe.k_cache_cpu[0, cpu_block_id] - cpu_v = oe.v_cache_cpu[0, cpu_block_id] - - k_ok, k_err = check_pattern(cpu_k, expected_k, f"Final CPU K block {cpu_block_id}") - v_ok, v_err = check_pattern(cpu_v, expected_v, f"Final CPU V block {cpu_block_id}") - - if k_ok and v_ok: - num_ok += 1 - # Only print first few and last few - if chunk_idx < 3 or chunk_idx >= num_chunks - 2: - actual_k_mean = cpu_k.float().mean().item() - actual_v_mean = cpu_v.float().mean().item() - print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), " - f"V={expected_v:.1f} ({actual_v_mean:.4f}) [OK]") - elif chunk_idx == 3: 
- print(f" ... (blocks 3 to {num_chunks - 3} verified OK) ...") - else: - num_fail += 1 - actual_k_mean = cpu_k.float().mean().item() - actual_v_mean = cpu_v.float().mean().item() - print(f" Block {cpu_block_id:3d}: K={expected_k:.1f} ({actual_k_mean:.4f}), " - f"V={expected_v:.1f} ({actual_v_mean:.4f}) [FAIL]") - if k_err: - errors.append(k_err) - if v_err: - errors.append(v_err) - - print(f"\nTotal: {num_ok} OK, {num_fail} FAIL out of {num_chunks} blocks") - - -def verify_block_coverage_summary(num_chunks: int): - """Verify that all chunks had complete block coverage during prefill.""" - print("\n" + "=" * 60) - print("Block Coverage Verification Summary") - print("=" * 60) - - num_ok = 0 - num_fail = 0 - total_blocks_expected = 0 - total_blocks_computed = 0 - - for chunk_idx in range(1, num_chunks): # Start from 1 (chunk 0 has no previous) - if chunk_idx not in block_coverage: - print(f" Chunk {chunk_idx}: NO COVERAGE DATA [FAIL]") - errors.append(f"Chunk {chunk_idx} has no block coverage data") - num_fail += 1 - continue - - coverage = block_coverage[chunk_idx] - expected = coverage['expected'] - actual = coverage['actual'] - missing = expected - actual - - total_blocks_expected += len(expected) - total_blocks_computed += len(actual) - - if not missing: - num_ok += 1 - else: - num_fail += 1 - - # Print summary - if num_fail == 0: - print(f" All {num_ok} chunks had complete block coverage [OK]") - print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})") - else: - print(f" {num_ok} chunks OK, {num_fail} chunks with missing blocks [FAIL]") - print(f" Total blocks computed: {total_blocks_computed} (expected: {total_blocks_expected})") - - # Verify the total is correct: sum of 0+1+2+...+(n-1) = n*(n-1)/2 - expected_total = num_chunks * (num_chunks - 1) // 2 - if total_blocks_expected == expected_total: - print(f" Expected total blocks matches formula: {expected_total} [OK]") - else: - print(f" Expected total mismatch: got {total_blocks_expected}, formula gives {expected_total} [FAIL]") - errors.append(f"Block coverage total mismatch") - - -def verify_load_operations_summary(num_chunks: int): - """Verify all H2D load operations transferred correct data.""" - print("\n" + "=" * 60) - print("H2D Load Operations Verification Summary") - print("=" * 60) - - if not load_operations: - print(" WARNING: No load operations recorded!") - print(" (This may indicate load verification was not installed)") - return - - num_ok = 0 - num_fail = 0 - loads_per_chunk = {} - - for op in load_operations: - chunk_idx = op['chunk_idx'] - if chunk_idx not in loads_per_chunk: - loads_per_chunk[chunk_idx] = [] - loads_per_chunk[chunk_idx].append(op) - - if op['k_ok'] and op['v_ok']: - num_ok += 1 - else: - num_fail += 1 - - # Print per-chunk summary for first/last chunks - for chunk_idx in sorted(loads_per_chunk.keys()): - ops = loads_per_chunk[chunk_idx] - chunk_ok = sum(1 for op in ops if op['k_ok'] and op['v_ok']) - chunk_fail = len(ops) - chunk_ok - - if chunk_idx < 4 or chunk_idx >= num_chunks - 2 or chunk_fail > 0: - # Show loaded block IDs in order - block_ids = [op['cpu_block_id'] for op in ops] - if chunk_fail == 0: - print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks {block_ids} [OK]") - else: - print(f" Chunk {chunk_idx:2d}: loaded {len(ops)} blocks, {chunk_fail} FAILED [FAIL]") - for op in ops: - if not (op['k_ok'] and op['v_ok']): - print(f" CPU block {op['cpu_block_id']} -> slot {op['slot_idx']}: " - f"expected K={op['expected_k']:.1f}/V={op['expected_v']:.1f}, 
" - f"got K={op['actual_k']:.4f}/V={op['actual_v']:.4f}") - elif chunk_idx == 4: - print(f" ... (chunks 4-{num_chunks - 3} load verification running silently) ...") - - # Print overall summary - print(f"\n Total load operations: {len(load_operations)}") - print(f" Successful: {num_ok}, Failed: {num_fail}") - - if num_fail == 0: - print(f" All H2D transfers verified correct [OK]") - else: - print(f" {num_fail} H2D transfers had incorrect data [FAIL]") - - -# ============================================================ -# Main Test Script -# ============================================================ - -if __name__ == "__main__": - print("=" * 60) - print("Test: CPU Offload Correctness with Distinctive KV Patterns") - print("=" * 60) - print(f"Input: {INPUT_LEN} tokens, {INPUT_LEN // BLOCK_SIZE} chunks") - print(f"GPU blocks: {NUM_GPU_BLOCKS}, Block size: {BLOCK_SIZE}") - print(f"Pattern: K=chunk_idx+1, V=-(chunk_idx+1)") - print() - - # 1. Initialize LLM - print("Initializing LLM...") - llm = LLM( - MODEL_PATH, - enforce_eager=True, - max_model_len=MAX_MODEL_LEN, - max_num_batched_tokens=MAX_MODEL_LEN, - enable_cpu_offload=True, - kvcache_block_size=BLOCK_SIZE, - num_gpu_blocks=NUM_GPU_BLOCKS, - dtype="float16", - ) - - # 2. Register hooks - hooks = register_hooks(llm) - print(f"Registered {len(hooks)} hooks") - - # 3. Install load verification (instrument load_to_slot_layer) - install_load_verification(llm) - - # 4. Generate prompt - seed(42) - prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] - num_chunks = INPUT_LEN // BLOCK_SIZE - - # 5. Run prefill - print("\n" + "=" * 60) - print("Running Prefill with KV Pattern Injection...") - print("=" * 60) - - sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1) - outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) - - # 6. Remove hooks and uninstall load verification - for hook in hooks: - hook.remove() - uninstall_load_verification() - - # 7. Final verification - verify_final_cpu_state(llm, num_chunks) - - # 8. Block coverage summary - verify_block_coverage_summary(num_chunks) - - # 9. H2D load operations summary - verify_load_operations_summary(num_chunks) - - # 10. Report results - print("\n" + "=" * 60) - if errors: - print(f"test_offload_correctness: FAILED ({len(errors)} errors)") - for err in errors[:10]: # Show first 10 errors - print(f" - {err}") - exit(1) - else: - print("test_offload_correctness: PASSED") - print("=" * 60) diff --git a/tests/test_offload_engine.py b/tests/test_offload_engine.py deleted file mode 100644 index 2df77bc..0000000 --- a/tests/test_offload_engine.py +++ /dev/null @@ -1,119 +0,0 @@ -""" -Test script for OffloadEngine - CPU-GPU KV cache transfer engine. - -Demonstrates: ring buffer, H2D/D2H transfers, CUDA events, KV access. 
-""" - -import torch -from nanovllm.kvcache.offload_engine import OffloadEngine - -# ============================================================ -# Utility Functions -# ============================================================ - -def verify(tensor: torch.Tensor, expected: float, name: str) -> None: - """Verify tensor contains expected value.""" - actual = tensor.mean().item() - assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}" - -# ============================================================ -# Configuration -# ============================================================ - -NUM_LAYERS = 4 -NUM_GPU_BLOCKS = 8 -NUM_CPU_BLOCKS = 16 -BLOCK_SIZE = 64 -NUM_KV_HEADS = 4 -HEAD_DIM = 32 - -# ============================================================ -# Main Test Script -# ============================================================ - -# 1. Initialize -engine = OffloadEngine( - num_layers=NUM_LAYERS, - num_gpu_blocks=NUM_GPU_BLOCKS, - num_cpu_blocks=NUM_CPU_BLOCKS, - block_size=BLOCK_SIZE, - num_kv_heads=NUM_KV_HEADS, - head_dim=HEAD_DIM, - dtype=torch.float16, -) - -# 2. Ring buffer slot management -for chunk_idx in range(12): - write_slot = engine.get_write_slot_for_prefill(chunk_idx) - load_slots = engine.get_load_slots_for_prefill(write_slot) - - print("chunk idx", chunk_idx, "write slots:", write_slot, "load slots:", load_slots) - - assert write_slot == chunk_idx % engine.num_ring_slots - assert write_slot not in load_slots - -assert engine.decode_slot == 0 -assert engine.get_load_slots_for_decode() == list(range(1, NUM_GPU_BLOCKS)) - -# 3. Per-slot per-layer H2D transfer -engine.k_cache_cpu[0, 0].fill_(42.0) -engine.v_cache_cpu[0, 0].fill_(42.5) - -engine.load_to_slot_layer(slot_idx=1, layer_id=0, cpu_block_id=0) -engine.wait_slot_layer(slot_idx=1, layer_id=0) - -verify(engine.k_cache_gpu[0, 1], 42.0, "H2D K") -verify(engine.v_cache_gpu[0, 1], 42.5, "H2D V") - -# 4. Compute-done event (pipeline safety) -engine.record_slot_compute_done(slot_idx=1, layer_id=0) - -engine.k_cache_cpu[0, 1].fill_(100.0) -engine.v_cache_cpu[0, 1].fill_(100.5) -engine.load_to_slot_layer(slot_idx=1, layer_id=0, cpu_block_id=1) -engine.wait_slot_layer(slot_idx=1, layer_id=0) - -verify(engine.k_cache_gpu[0, 1], 100.0, "Reuse K") -verify(engine.v_cache_gpu[0, 1], 100.5, "Reuse V") - -# 5. D2H offload -engine.k_cache_gpu[1, 2].fill_(77.0) -engine.v_cache_gpu[1, 2].fill_(77.5) - -engine.offload_slot_to_cpu(slot_idx=2, cpu_block_id=5) -engine.wait_slot_offload(slot_idx=2) - -verify(engine.k_cache_cpu[1, 5], 77.0, "D2H K") -verify(engine.v_cache_cpu[1, 5], 77.5, "D2H V") - -# 6. KV access methods -k, v = engine.get_kv_for_slot(slot_idx=1, layer_id=0) -assert k.shape == (1, BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM) - -k, v = engine.get_kv_for_slots(layer_id=0, slot_indices=[0, 1, 2]) -assert k.shape == (1, 3 * BLOCK_SIZE, NUM_KV_HEADS, HEAD_DIM) - -engine.k_cache_gpu[0, engine.decode_slot].fill_(33.0) -k, v = engine.get_kv_for_decode_slot_accumulated(layer_id=0, num_tokens=10) -assert k.shape == (1, 10, NUM_KV_HEADS, HEAD_DIM) -verify(k, 33.0, "Decode slot K") - -# 7. Batch transfer -cpu_blocks = [2, 3, 4] -gpu_slots = [3, 4, 5] -for cpu_id in cpu_blocks: - engine.k_cache_cpu[0, cpu_id].fill_(50.0 + cpu_id) - -engine.load_cpu_blocks_to_gpu_slots(layer_id=0, cpu_block_ids=cpu_blocks, gpu_slot_ids=gpu_slots) - -for cpu_id, gpu_slot in zip(cpu_blocks, gpu_slots): - verify(engine.k_cache_gpu[0, gpu_slot], 50.0 + cpu_id, f"Batch slot {gpu_slot}") - -# 8. 
Gather indices (CUDA graph compatible) -engine.update_gather_indices(layer_id=0, mappings=[(0, 0), (1, 1), (2, 2)]) -assert engine.gather_indices_gpu[0, :3].tolist() == [0, 1, 2] - -engine.clear_gather_indices(layer_id=0) -assert engine.gather_indices_gpu[0, 0].item() == -1 - -print("test_offload_engine: PASSED") diff --git a/tests/test_prefill.py b/tests/test_prefill.py deleted file mode 100644 index 9e501c0..0000000 --- a/tests/test_prefill.py +++ /dev/null @@ -1,51 +0,0 @@ -""" -Test script for chunked prefill with CPU offload. - -Demonstrates: LLM initialization, prefill execution with CPU offload enabled. -""" - -import os -os.environ["NANOVLLM_LOG_LEVEL"] = "DEBUG" - -from random import randint, seed -from nanovllm import LLM, SamplingParams - - -# ============================================================ -# Configuration -# ============================================================ - -MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/") -MAX_MODEL_LEN = 32 * 1024 -NUM_GPU_BLOCKS = 2 -INPUT_LEN = 16 * 1024 - -# ============================================================ -# Main Test Script -# ============================================================ - -# 1. Initialize LLM with CPU offload -llm = LLM( - MODEL_PATH, - enforce_eager=True, - max_model_len=MAX_MODEL_LEN, - max_num_batched_tokens=MAX_MODEL_LEN, - enable_cpu_offload=True, - kvcache_block_size=1024, - num_gpu_blocks=NUM_GPU_BLOCKS, -) - -# 2. Generate random prompt tokens -seed(42) -prompt_token_ids = [[randint(0, 10000) for _ in range(INPUT_LEN)]] - -# 3. Run prefill (max_tokens=1 to focus on prefill only) -sampling_params = SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=1) -outputs = llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) - -# 4. Verify output -assert len(outputs) == 1 -assert "token_ids" in outputs[0] -assert len(outputs[0]["token_ids"]) == 1 - -print("test_prefill: PASSED") diff --git a/tests/test_sequential.py b/tests/test_sequential.py new file mode 100644 index 0000000..c2abb21 --- /dev/null +++ b/tests/test_sequential.py @@ -0,0 +1,199 @@ +""" +Sequential inference test for LLM. + +Tests: After completing one prompt, the system can correctly handle +a second prompt with a clean state (first prompt's KV cache deallocated). +""" + +import os +os.environ["NANOVLLM_LOG_LEVEL"] = "INFO" + +import argparse +from nanovllm import LLM, SamplingParams +from utils import generate_needle_prompt, check_needle_answer + + +def run_sequential_test( + model_path: str, + max_model_len: int, + input_len: int, + num_gpu_blocks: int = 4, + block_size: int = 1024, + enable_cpu_offload: bool = False, + verbose: bool = True, +) -> bool: + """ + Run a sequential inference test with three different prompts. + + Each prompt has a different needle value. All three must be retrieved correctly.
+ """ + if verbose: + print(f"\n{'='*60}") + print(f"Sequential Inference Test") + print(f"{'='*60}") + print(f"Model: {model_path}") + print(f"Max model len: {max_model_len}") + print(f"Input length: {input_len}") + print(f"Block size: {block_size}") + print(f"CPU offload: {enable_cpu_offload}") + print(f"{'='*60}\n") + + # Initialize LLM once + llm_kwargs = { + "enforce_eager": True, + "max_model_len": max_model_len, + "max_num_batched_tokens": max_model_len, + "enable_cpu_offload": enable_cpu_offload, + "kvcache_block_size": block_size, + } + if enable_cpu_offload: + llm_kwargs["num_gpu_blocks"] = num_gpu_blocks + + llm = LLM(model_path, **llm_kwargs) + + sampling_params = SamplingParams( + temperature=0.6, + max_tokens=32, + ) + + # ============================================================ + # Test 1: First prompt with needle value "1234" + # ============================================================ + needle_value_1 = "1234" + if verbose: + print(f"\n[Test 1] Generating prompt with needle value: {needle_value_1}") + + prompt_1, expected_1 = generate_needle_prompt( + tokenizer=llm.tokenizer, + target_length=input_len, + needle_position=0.5, + needle_value=needle_value_1, + ) + + outputs_1 = llm.generate([prompt_1], sampling_params, use_tqdm=True) + output_text_1 = outputs_1[0]["text"] + passed_1 = check_needle_answer(output_text_1, expected_1) + + if verbose: + print(f" Expected: {expected_1}") + print(f" Output: {output_text_1[:100]}...") + print(f" Status: {'PASSED' if passed_1 else 'FAILED'}") + + # ============================================================ + # Test 2: Second prompt with needle value "5678" + # ============================================================ + needle_value_2 = "5678" + if verbose: + print(f"\n[Test 2] Generating prompt with needle value: {needle_value_2}") + + prompt_2, expected_2 = generate_needle_prompt( + tokenizer=llm.tokenizer, + target_length=input_len, + needle_position=0.5, + needle_value=needle_value_2, + ) + + outputs_2 = llm.generate([prompt_2], sampling_params, use_tqdm=True) + output_text_2 = outputs_2[0]["text"] + passed_2 = check_needle_answer(output_text_2, expected_2) + + if verbose: + print(f" Expected: {expected_2}") + print(f" Output: {output_text_2[:100]}...") + print(f" Status: {'PASSED' if passed_2 else 'FAILED'}") + + # ============================================================ + # Test 3: Third prompt with a fresh needle value to ensure no cross-contamination + # ============================================================ + needle_value_3 = "9999" + if verbose: + print(f"\n[Test 3] Generating prompt with needle value: {needle_value_3}") + + prompt_3, expected_3 = generate_needle_prompt( + tokenizer=llm.tokenizer, + target_length=input_len, + needle_position=0.5, + needle_value=needle_value_3, + ) + + outputs_3 = llm.generate([prompt_3], sampling_params, use_tqdm=True) + output_text_3 = outputs_3[0]["text"] + passed_3 = check_needle_answer(output_text_3, expected_3) + + if verbose: + print(f" Expected: {expected_3}") + print(f" Output: {output_text_3[:100]}...") + print(f" Status: {'PASSED' if passed_3 else 'FAILED'}") + + # ============================================================ + # Summary + # ============================================================ + all_passed = passed_1 and passed_2 and passed_3 + + if verbose: + print(f"\n{'='*60}") + print(f"Summary") + print(f"{'='*60}") + print(f"Test 1 (needle={needle_value_1}): {'PASSED' if passed_1 else 'FAILED'}") + print(f"Test 2 (needle={needle_value_2}): {'PASSED' 
if passed_2 else 'FAILED'}") + print(f"Test 3 (needle={needle_value_3}): {'PASSED' if passed_3 else 'FAILED'}") + print(f"Overall: {'PASSED' if all_passed else 'FAILED'}") + print(f"{'='*60}\n") + + return all_passed + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Sequential inference test") + parser.add_argument( + "--model", "-m", + type=str, + default=os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/"), + help="Path to model" + ) + parser.add_argument( + "--max-model-len", + type=int, + default=36 * 1024, + help="Maximum model context length" + ) + parser.add_argument( + "--input-len", + type=int, + default=8 * 1024, + help="Target input sequence length" + ) + parser.add_argument( + "--num-gpu-blocks", + type=int, + default=2, + help="Number of GPU blocks for CPU offload" + ) + parser.add_argument( + "--block-size", + type=int, + default=1024, + help="KV cache block size" + ) + parser.add_argument( + "--enable-offload", + action="store_true", + help="Enable CPU offload" + ) + args = parser.parse_args() + + passed = run_sequential_test( + model_path=args.model, + max_model_len=args.max_model_len, + input_len=args.input_len, + num_gpu_blocks=args.num_gpu_blocks, + block_size=args.block_size, + enable_cpu_offload=args.enable_offload, + verbose=True, + ) + + if passed: + print("test_sequential: PASSED") + else: + print("test_sequential: FAILED") + exit(1) diff --git a/tests/test_sgdma.py b/tests/test_sgdma.py deleted file mode 100644 index f00ad82..0000000 --- a/tests/test_sgdma.py +++ /dev/null @@ -1,176 +0,0 @@ -""" -Tests for CUDA sgDMA (cudaMemcpy2D) extension. - -Author: Zijie Tian -""" - -import torch -import time -from nanovllm.comm import memcpy_2d - -# ============================================================ -# Configuration -# ============================================================ - -class Config: - num_layers = 32 - num_blocks = 10 - block_size = 4096 - num_kv_heads = 8 - head_dim = 128 - dtype = torch.float16 - - @property - def features_per_block(self): - return self.block_size * self.num_kv_heads * self.head_dim - - @property - def bytes_per_block(self): - return self.features_per_block * self.dtype.itemsize - - @property - def bytes_per_layer(self): - return self.num_blocks * self.bytes_per_block - - -# ============================================================ -# Performance Benchmark -# ============================================================ - -def benchmark_sgdma(): - """Benchmark cudaMemcpy2D vs standard PyTorch methods.""" - print("\n=== Performance Benchmark ===") - - cfg = Config() - - print(f" Configuration:") - print(f" num_layers: {cfg.num_layers}") - print(f" num_blocks: {cfg.num_blocks}") - print(f" block_size: {cfg.block_size}") - print(f" dtype: {cfg.dtype}") - print(f" bytes_per_block: {cfg.bytes_per_block / 1024:.1f} KB") - print(f" total transfer size: {cfg.num_layers * cfg.bytes_per_block / 1024 / 1024:.1f} MB") - - num_iterations = 10 - warmup = 3 - test_block_id = 5 - - # Allocate memory - cpu_strided = torch.randn( - cfg.num_layers, - cfg.num_blocks, - cfg.features_per_block, - dtype=cfg.dtype, - pin_memory=True - ) - - # ======================================== - # Method A: cudaMemcpy2D with sgDMA - # ======================================== - gpu_buffer_a = torch.empty(cfg.num_layers, cfg.features_per_block, dtype=cfg.dtype, device='cuda') - - spitch = cfg.bytes_per_layer - dpitch = cfg.bytes_per_block - width = cfg.bytes_per_block - height = cfg.num_layers - src_view = cpu_strided[:, test_block_id, 
:] - - # Warmup - for _ in range(warmup): - memcpy_2d(gpu_buffer_a, src_view, dpitch, spitch, width, height, "h2d") - torch.cuda.synchronize() - - # Benchmark - start = time.perf_counter() - for _ in range(num_iterations): - memcpy_2d(gpu_buffer_a, src_view, dpitch, spitch, width, height, "h2d") - torch.cuda.synchronize() - elapsed_a = time.perf_counter() - start - - avg_time_a = elapsed_a / num_iterations * 1000 # ms - bandwidth_a = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_a - - print(f"\n Method A (cudaMemcpy2D sgDMA):") - print(f" Avg time: {avg_time_a:.3f} ms") - print(f" Bandwidth: {bandwidth_a:.2f} GB/s") - - # ======================================== - # Method B: PyTorch .cuda() on strided view - # ======================================== - # Warmup - for _ in range(warmup): - _ = cpu_strided[:, test_block_id, :].cuda() - torch.cuda.synchronize() - - # Benchmark - start = time.perf_counter() - for _ in range(num_iterations): - _ = cpu_strided[:, test_block_id, :].cuda() - torch.cuda.synchronize() - elapsed_b = time.perf_counter() - start - - avg_time_b = elapsed_b / num_iterations * 1000 # ms - bandwidth_b = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_b - - print(f"\n Method B (PyTorch .cuda() on strided):") - print(f" Avg time: {avg_time_b:.3f} ms") - print(f" Bandwidth: {bandwidth_b:.2f} GB/s") - - # ======================================== - # Method C: PyTorch .cuda() on contiguous (pinned) - # ======================================== - # Create contiguous version with pinned memory - cpu_contiguous = torch.empty( - cfg.num_layers, - cfg.features_per_block, - dtype=cfg.dtype, - pin_memory=True - ) - cpu_contiguous.copy_(cpu_strided[:, test_block_id, :]) - - # Warmup - for _ in range(warmup): - _ = cpu_contiguous.cuda() - torch.cuda.synchronize() - - # Benchmark - start = time.perf_counter() - for _ in range(num_iterations): - _ = cpu_contiguous.cuda() - torch.cuda.synchronize() - elapsed_c = time.perf_counter() - start - - avg_time_c = elapsed_c / num_iterations * 1000 # ms - bandwidth_c = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_c - - print(f"\n Method C (PyTorch .cuda() on contiguous):") - print(f" Avg time: {avg_time_c:.3f} ms") - print(f" Bandwidth: {bandwidth_c:.2f} GB/s") - - # Summary - print(f"\n ========================================") - print(f" Performance Summary:") - print(f" Method A vs Method B: {bandwidth_a / bandwidth_b:.2f}x speedup") - print(f" Method A vs Method C: {bandwidth_a / bandwidth_c * 100:.2f}%") - print(f" ========================================") - - -# ============================================================ -# Main -# ============================================================ - -if __name__ == "__main__": - print("=== CUDA sgDMA (cudaMemcpy2D) Benchmark ===") - - # Check CUDA availability - if not torch.cuda.is_available(): - print("CUDA not available. Skipping benchmark.") - exit(1) - - # Print GPU info - print(f"Using GPU: {torch.cuda.get_device_name()}") - - # Run benchmark - benchmark_sgdma() - - print("\n=== Benchmark Complete ===") diff --git a/tests/test_torch_steppable.py b/tests/test_torch_steppable.py deleted file mode 100644 index 63c38a7..0000000 --- a/tests/test_torch_steppable.py +++ /dev/null @@ -1,121 +0,0 @@ -""" -Test TorchSteppable: Print activation statistics at each layer. 
- -Usage: - python tests/test_torch_steppable.py -""" - -import os -import sys -from pathlib import Path - -sys.path.insert(0, str(Path(__file__).parent)) - -import torch -from transformers import AutoTokenizer -from modeling_qwen3 import Qwen3ForCausalLM -from nanovllm.debug.adapters.torch_adapter import TorchSteppable -from utils import generate_needle_prompt, check_needle_answer - -# ============================================================ -# Config -# ============================================================ -MODEL_PATH = os.path.expanduser("~/models/Qwen3-0.6B/") -INPUT_LEN = 512 -MAX_NEW_TOKENS = 20 -DTYPE = torch.float16 - -# ============================================================ -# Load Model -# ============================================================ -print("Loading model...") -tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH) -model = Qwen3ForCausalLM.from_pretrained(MODEL_PATH, dtype=DTYPE) -model = model.to("cuda").eval() - -# ============================================================ -# Prepare Input (using needle-in-haystack prompt) -# ============================================================ -prompt, expected_answer = generate_needle_prompt( - tokenizer, - target_length=INPUT_LEN, - needle_position=0.5, - needle_value="7492", - use_chat_template=False, - verbose=True, -) -input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda") -print(f"Input shape: {input_ids.shape}") -print(f"Expected answer: {expected_answer}\n") - -# ============================================================ -# Create Steppable Model (reused for prefill + decode) -# ============================================================ -steppable = TorchSteppable(model) - -# ============================================================ -# Prefill Phase: Print activation stats -# ============================================================ -print("=" * 85) -print("PREFILL PHASE") -print("=" * 85) -print(f"{'Layer':<15} {'Shape':<25} {'Mean':>10} {'Std':>10} {'Min':>10} {'Max':>10}") -print("-" * 85) - -current_ids = input_ids.clone() -logits = None - -for bp in steppable.step(current_ids): - t = bp.tensor.float() - shape_str = str(list(t.shape)) - print(f"{bp.name:<15} {shape_str:<25} {t.mean():>10.4f} {t.std():>10.4f} {t.min():>10.4f} {t.max():>10.4f}") - if bp.name == "LM Head": - logits = bp.tensor - -# Get first token from prefill -next_token_id = logits[0, -1].argmax().item() -next_token = tokenizer.decode(next_token_id) -current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1) -generated_tokens = [next_token] - -# ============================================================ -# Decode Phase: Only print generated tokens -# ============================================================ -print("\n" + "=" * 85) -print("DECODE PHASE") -print("=" * 85) -print(f"Step 1: {next_token!r}") - -for step in range(2, MAX_NEW_TOKENS + 1): - # Forward pass (reuse same steppable) - for bp in steppable.step(current_ids): - if bp.name == "LM Head": - logits = bp.tensor - - # Get next token (greedy) - next_token_id = logits[0, -1].argmax().item() - next_token = tokenizer.decode(next_token_id) - generated_tokens.append(next_token) - - print(f"Step {step:2}: {next_token!r}") - - # Stop if EOS - if next_token_id == tokenizer.eos_token_id: - print(" (EOS)") - break - - # Append to sequence - current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=current_ids.device)], dim=1) - -# 
============================================================ -# Result -# ============================================================ -print("\n" + "=" * 85) -print("RESULT") -print("=" * 85) -generated_text = "".join(generated_tokens) -print(f"Generated: {generated_text!r}") -print(f"Expected: {expected_answer}") -print(f"Answer: {'CORRECT!' if check_needle_answer(generated_text, expected_answer) else 'INCORRECT'}") - -print("\ntest_torch_steppable: PASSED")