From 6ec1b23982a3c5e536f095f6a3c366694cd011b4 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Wed, 24 Dec 2025 21:57:51 +0800 Subject: [PATCH] [WIP] NEED to modify communication. --- .claude/rules/no-extra-docs.md | 40 ++++++++++ .gitignore | 4 +- CLAUDE.md | 67 ++++++++++++++++ nanovllm/kvcache/offload_engine.py | 5 ++ nanovllm/layers/attention.py | 16 +++- scripts/export_traces.sh | 71 +++++++++++++++++ scripts/profile_offload.sh | 67 ++++++++++++++++ tests/test_pinned_memory_slice.py | 70 ++++++++++++++++ tests/test_pinned_transfer.py | 124 +++++++++++++++++++++++++++++ 9 files changed, 462 insertions(+), 2 deletions(-) create mode 100644 .claude/rules/no-extra-docs.md create mode 100755 scripts/export_traces.sh create mode 100755 scripts/profile_offload.sh create mode 100644 tests/test_pinned_memory_slice.py create mode 100644 tests/test_pinned_transfer.py diff --git a/.claude/rules/no-extra-docs.md b/.claude/rules/no-extra-docs.md new file mode 100644 index 0000000..87a806b --- /dev/null +++ b/.claude/rules/no-extra-docs.md @@ -0,0 +1,40 @@ +# Documentation Policy + +## Do Not Create Unnecessary Documentation + +**IMPORTANT**: Do NOT create extra markdown documentation files unless explicitly requested by the user. + +### What NOT to do: + +- ❌ Do NOT create README files proactively +- ❌ Do NOT create analysis documents (*.md) after completing tasks +- ❌ Do NOT create tutorial/guide documents +- ❌ Do NOT create summary documents + +### What TO do: + +- ✅ Only create documentation when user explicitly asks for it +- ✅ Provide information directly in conversation instead +- ✅ Update existing documentation if changes require it +- ✅ Add inline code comments where necessary + +### Exceptions: + +Documentation is acceptable ONLY when: +1. User explicitly requests "create a README" or "write documentation" +2. Updating existing documentation to reflect code changes +3. Adding inline comments/docstrings to code itself + +### Examples: + +**Bad** (Don't do this): +``` +User: "Profile the code" +Assistant: [Creates profiling_results.md after profiling] +``` + +**Good** (Do this instead): +``` +User: "Profile the code" +Assistant: [Runs profiling, shows results in conversation] +``` diff --git a/.gitignore b/.gitignore index 42ee68f..eae6416 100644 --- a/.gitignore +++ b/.gitignore @@ -192,4 +192,6 @@ cython_debug/ # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data # refer to https://docs.cursor.com/context/ignore-files .cursorignore -.cursorindexingignore \ No newline at end of file +.cursorindexingignore + +results/ \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 7416e8b..1a62dd9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -296,3 +296,70 @@ Assertion `index out of bounds: 0 <= ... < 40960` failed | CPU Offload (bench_offload.py) | ~7,200 | ~3.5 | CPU offload trades performance for memory efficiency, enabling long-context inference on limited GPU memory. + +## TODO: Performance Optimizations + +### 1. Fix Non-Contiguous CPU Cache Layout (High Priority) + +**Problem**: Device-to-Pageable transfers causing 16x slowdown in CPU offload. + +**Root Cause**: +Current CPU cache layout `[num_layers, num_cpu_blocks, ...]` causes non-contiguous memory access when slicing `k_cache_cpu[:, cpu_block_id]`. Although the tensor is pinned, CUDA runtime falls back to slow pageable transfer path because the slice is non-contiguous. 
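+
+A minimal sketch of the contiguity check behind this claim (sizes mirror `tests/test_pinned_transfer.py`; assumes PyTorch with CUDA pinned-memory support):
+
+```python
+import torch
+
+# Current layout: [num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim]
+cache = torch.zeros(8, 16, 1024, 8, 64, dtype=torch.float16,
+                    device="cpu", pin_memory=True)
+
+per_block = cache[:, 0]           # slice taken when offloading a single block
+print(per_block.is_pinned())      # True  - the underlying storage is still pinned
+print(per_block.is_contiguous())  # False - strided view, so the copy takes the slow path
+
+# A permuted *view* does not help; the tensor must be allocated block-major:
+print(cache.permute(1, 0, 2, 3, 4)[0].is_contiguous())  # False
+```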
+ +**Evidence from Profiling** (`tests/test_pinned_transfer.py` + nsys): +``` +Non-contiguous slice (current): +- Transfer type: Device -> Pageable +- Avg duration: 5.825 ms +- Bandwidth: 1.44 GB/s + +Contiguous layout (optimized): +- Transfer type: Device -> Pinned +- Avg duration: 0.364 ms +- Bandwidth: 23.11 GB/s + +Performance gain: 16x faster! +``` + +**Technical Details**: +- Pinned memory requires both `pin_memory=True` AND contiguous layout for fast DMA +- Non-contiguous slice forces CUDA to: + 1. Allocate temporary pageable buffer on CPU + 2. Copy non-contiguous data to buffer (CPU overhead) + 3. Transfer from pageable buffer to GPU (slow path) +- PCIe DMA engine requires contiguous memory blocks for optimal throughput + +**Solution**: +Change CPU cache tensor layout from: +```python +# Current (non-contiguous when accessing per-block): +k_cache_cpu = torch.zeros( + num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim, + dtype=dtype, device="cpu", pin_memory=True +) +# Access: k_cache_cpu[:, cpu_block_id] -> non-contiguous! + +# Optimized (contiguous per-block access): +k_cache_cpu = torch.zeros( + num_cpu_blocks, num_layers, block_size, num_kv_heads, head_dim, + dtype=dtype, device="cpu", pin_memory=True +) +# Access: k_cache_cpu[cpu_block_id] -> contiguous! +``` + +**Files to modify**: +- `nanovllm/kvcache/offload_engine.py`: + - Lines 104-111: Change tensor allocation layout + - All methods accessing CPU cache: update indexing + - `load_to_slot_layer()`, `offload_slot_to_cpu()`, `offload_slot_layer_to_cpu()` +- Update any other code that accesses `k_cache_cpu`/`v_cache_cpu` + +**Expected Impact**: +- 16x faster D2H transfers in CPU offload mode +- Overall prefill throughput improvement: ~2-3x (D2H is currently the bottleneck) +- No change to API or functionality, pure performance optimization + +**Reference**: +- Test: `tests/test_pinned_transfer.py` +- Profiling: `results/nsys/pinned_transfer_20251224_213158.nsys-rep` +- Analysis: See traces showing Device->Pageable vs Device->Pinned diff --git a/nanovllm/kvcache/offload_engine.py b/nanovllm/kvcache/offload_engine.py index f5cea8a..e488df2 100644 --- a/nanovllm/kvcache/offload_engine.py +++ b/nanovllm/kvcache/offload_engine.py @@ -8,6 +8,7 @@ Key design principles for CUDA Graph compatibility: """ import torch +import torch.cuda.nvtx from torch import Tensor from typing import Dict, List, Tuple, Optional from dataclasses import dataclass @@ -660,6 +661,7 @@ class OffloadEngine: """ logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]") + torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]") with torch.cuda.stream(self.transfer_stream_main): # Wait for previous compute on this slot to complete before overwriting # This prevents data race: transfer must not start until attention finishes reading @@ -672,6 +674,7 @@ class OffloadEngine: self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True ) self.ring_slot_ready[slot_idx][layer_id].record(self.transfer_stream_main) + torch.cuda.nvtx.range_pop() def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None: """ @@ -718,6 +721,7 @@ class OffloadEngine: """ logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]") + torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]") with torch.cuda.stream(self.transfer_stream_main): self.transfer_stream_main.wait_stream(self.compute_stream) self.k_cache_cpu[:, cpu_block_id].copy_( @@ -727,6 +731,7 @@ class OffloadEngine: 
self.v_cache_gpu[:, slot_idx], non_blocking=True ) self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main) + torch.cuda.nvtx.range_pop() def wait_slot_offload(self, slot_idx: int) -> None: """Wait for slot offload to complete.""" diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index adb8546..a0f5fbe 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -1,5 +1,6 @@ import logging import torch +import torch.cuda.nvtx from torch import nn import triton import triton.language as tl @@ -117,6 +118,9 @@ class Attention(nn.Module): """ from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs + current_chunk_idx = context.current_chunk_idx + torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}") + # q, k, v shape: [total_tokens, num_heads, head_dim] # Reshape for flash attention: [batch, seq, heads, dim] q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim] @@ -128,7 +132,6 @@ class Attention(nn.Module): kvcache_manager = context.kvcache_manager seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None - current_chunk_idx = context.current_chunk_idx if kvcache_manager is not None and seq is not None and self.layer_id >= 0: # Get prefilled CPU blocks (blocks from previous chunks) @@ -170,6 +173,7 @@ class Attention(nn.Module): ) # Compute attention against current chunk's KV (with causal mask) + torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)") current_o, current_lse = flash_attn_with_lse( q_batched, k_batched, @@ -177,13 +181,17 @@ class Attention(nn.Module): softmax_scale=self.scale, causal=True, ) + torch.cuda.nvtx.range_pop() # Merge with accumulated if o_acc is None: final_o = current_o else: + torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}") final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse) + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_pop() # ChunkedPrefill # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim] return final_o.squeeze(0) @@ -287,6 +295,8 @@ class Attention(nn.Module): offload_engine.load_to_slot_layer(slot_A, self.layer_id, cpu_block_table[0]) for block_idx in range(num_blocks): + torch.cuda.nvtx.range_push(f"PipelineBlock: L{self.layer_id} B{block_idx}") + # Alternate between slot_A and slot_B current_slot = slot_A if block_idx % 2 == 0 else slot_B next_slot = slot_B if block_idx % 2 == 0 else slot_A @@ -300,12 +310,14 @@ class Attention(nn.Module): offload_engine.load_to_slot_layer(next_slot, self.layer_id, cpu_block_table[block_idx + 1]) # Compute attention on current slot's data + torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}") prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id) prev_o, prev_lse = flash_attn_with_lse( q_batched, prev_k, prev_v, softmax_scale=self.scale, causal=False, ) + torch.cuda.nvtx.range_pop() # Record compute done - this allows the next round to safely load into this slot offload_engine.record_slot_compute_done(current_slot, self.layer_id) @@ -316,6 +328,8 @@ class Attention(nn.Module): else: o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse) + torch.cuda.nvtx.range_pop() # PipelineBlock + return o_acc, lse_acc def _chunked_decode_attention( diff --git a/scripts/export_traces.sh b/scripts/export_traces.sh new file mode 100755 index 0000000..5247dfe --- /dev/null +++ b/scripts/export_traces.sh @@ -0,0 
+1,71 @@ +#!/bin/bash + +# Export detailed profiling traces from nsys report +# +# Usage: +# bash scripts/export_traces.sh +# +# Example: +# bash scripts/export_traces.sh results/nsys/attention_offload_20251224_205806.nsys-rep + +set -e + +if [ $# -eq 0 ]; then + echo "Usage: $0 " + echo "Example: $0 results/nsys/attention_offload_20251224_205806.nsys-rep" + exit 1 +fi + +NSYS_REPORT="$1" +BASENAME=$(basename "$NSYS_REPORT" .nsys-rep) +OUTPUT_DIR="results/nsys/traces" + +mkdir -p "$OUTPUT_DIR" + +echo "============================================================" +echo "Exporting traces from: $NSYS_REPORT" +echo "Output directory: $OUTPUT_DIR" +echo "============================================================" + +# Export NVTX Push/Pop trace (shows timeline and nesting) +echo "" +echo "[1/4] Exporting NVTX Push/Pop trace..." +nsys stats --report nvtx_pushpop_trace "$NSYS_REPORT" \ + > "$OUTPUT_DIR/${BASENAME}_nvtx_trace.txt" 2>&1 +echo " -> $OUTPUT_DIR/${BASENAME}_nvtx_trace.txt" + +# Export CUDA GPU trace (shows kernel execution timeline) +echo "" +echo "[2/4] Exporting CUDA GPU trace..." +nsys stats --report cuda_gpu_trace "$NSYS_REPORT" \ + > "$OUTPUT_DIR/${BASENAME}_cuda_gpu_trace.txt" 2>&1 +echo " -> $OUTPUT_DIR/${BASENAME}_cuda_gpu_trace.txt" + +# Export CUDA API trace (shows API calls) +echo "" +echo "[3/4] Exporting CUDA API trace..." +nsys stats --report cuda_api_trace "$NSYS_REPORT" \ + > "$OUTPUT_DIR/${BASENAME}_cuda_api_trace.txt" 2>&1 +echo " -> $OUTPUT_DIR/${BASENAME}_cuda_api_trace.txt" + +# Export NVTX kernel summary (shows which kernels ran within NVTX ranges) +echo "" +echo "[4/4] Exporting NVTX kernel summary..." +nsys stats --report nvtx_kern_sum "$NSYS_REPORT" \ + > "$OUTPUT_DIR/${BASENAME}_nvtx_kern_sum.txt" 2>&1 +echo " -> $OUTPUT_DIR/${BASENAME}_nvtx_kern_sum.txt" + +echo "" +echo "============================================================" +echo "Traces exported successfully!" 
+echo "============================================================" +echo "" +echo "Key files:" +echo " - nvtx_trace.txt: Timeline with NVTX markers (shows nesting and timing)" +echo " - cuda_gpu_trace.txt: GPU kernel execution timeline" +echo " - cuda_api_trace.txt: CUDA API call timeline" +echo " - nvtx_kern_sum.txt: Kernels grouped by NVTX ranges" +echo "" +echo "For visual analysis, open in Nsight Systems GUI:" +echo " nsight-sys $NSYS_REPORT" +echo "============================================================" diff --git a/scripts/profile_offload.sh b/scripts/profile_offload.sh new file mode 100755 index 0000000..06a0c02 --- /dev/null +++ b/scripts/profile_offload.sh @@ -0,0 +1,67 @@ +#!/bin/bash + +# Profile test_attention_offload.py using NVIDIA Nsight Systems +# +# Usage: +# bash scripts/profile_offload.sh +# +# Output: +# results/nsys/attention_offload_.nsys-rep +# +# View results: +# nsight-sys results/nsys/attention_offload_.nsys-rep + +set -e + +# Configuration +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" +OUTPUT_DIR="$PROJECT_ROOT/results/nsys" +TEST_SCRIPT="$PROJECT_ROOT/tests/test_attention_offload.py" + +# Create output directory if needed +mkdir -p "$OUTPUT_DIR" + +# Generate timestamp for unique filename +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +OUTPUT_FILE="$OUTPUT_DIR/attention_offload_$TIMESTAMP" + +echo "============================================================" +echo "NVIDIA Nsight Systems Profiling" +echo "============================================================" +echo "Test script: $TEST_SCRIPT" +echo "Output file: $OUTPUT_FILE.nsys-rep" +echo "" + +# nsys profile options: +# --trace=cuda,nvtx,osrt,cudnn,cublas : Trace CUDA API, NVTX markers, OS runtime, cuDNN, cuBLAS +# --cuda-memory-usage=true : Track CUDA memory allocations +# --stats=true : Generate summary statistics +# --force-overwrite=true : Overwrite existing output file +# --output= : Output file path (without .nsys-rep extension) + +echo "Running nsys profile..." +echo "" + +nsys profile \ + --trace=cuda,nvtx,osrt,cudnn,cublas \ + --cuda-memory-usage=true \ + --stats=true \ + --force-overwrite=true \ + --output="$OUTPUT_FILE" \ + python "$TEST_SCRIPT" + +echo "" +echo "============================================================" +echo "Profiling completed successfully!" +echo "============================================================" +echo "Output file: $OUTPUT_FILE.nsys-rep" +echo "" +echo "To view results in GUI:" +echo " nsight-sys $OUTPUT_FILE.nsys-rep" +echo "" +echo "To export statistics:" +echo " nsys stats --report cuda_api_sum $OUTPUT_FILE.nsys-rep" +echo " nsys stats --report cuda_gpu_kern_sum $OUTPUT_FILE.nsys-rep" +echo " nsys stats --report cuda_gpu_mem_size_sum $OUTPUT_FILE.nsys-rep" +echo "============================================================" diff --git a/tests/test_pinned_memory_slice.py b/tests/test_pinned_memory_slice.py new file mode 100644 index 0000000..d948008 --- /dev/null +++ b/tests/test_pinned_memory_slice.py @@ -0,0 +1,70 @@ +""" +Test if slicing maintains pinned memory property. +""" + +import torch + +print("=" * 60) +print("Test: Pinned Memory Property with Slicing") +print("=" * 60) + +# Create a pinned tensor with shape similar to k_cache_cpu +# [num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim] +tensor = torch.zeros(8, 16, 1024, 8, 64, dtype=torch.float16, device="cpu", pin_memory=True) + +print(f"\n1. 
Original tensor:") +print(f" - Shape: {tensor.shape}") +print(f" - is_pinned(): {tensor.is_pinned()}") +print(f" - is_contiguous(): {tensor.is_contiguous()}") + +# Test slicing operation (what we do in offload_slot_to_cpu) +slice_view = tensor[:, 0] # Same as k_cache_cpu[:, cpu_block_id] + +print(f"\n2. Sliced tensor [:, 0]:") +print(f" - Shape: {slice_view.shape}") +print(f" - is_pinned(): {slice_view.is_pinned()}") +print(f" - is_contiguous(): {slice_view.is_contiguous()}") + +# Test if contiguous() helps +contiguous_slice = tensor[:, 0].contiguous() + +print(f"\n3. Contiguous slice [:, 0].contiguous():") +print(f" - Shape: {contiguous_slice.shape}") +print(f" - is_pinned(): {contiguous_slice.is_pinned()}") +print(f" - is_contiguous(): {contiguous_slice.is_contiguous()}") + +# Test copy behavior +gpu_tensor = torch.zeros(8, 4, 1024, 8, 64, dtype=torch.float16, device="cuda") +gpu_slice = gpu_tensor[:, 0] + +print(f"\n4. GPU tensor slice:") +print(f" - Shape: {gpu_slice.shape}") +print(f" - is_contiguous(): {gpu_slice.is_contiguous()}") + +# Simulate the problematic copy operation +print(f"\n5. Testing copy operations:") + +# Method 1: Direct slice copy (current approach - SLOW) +slice_dst = tensor[:, 1] +print(f" Method 1 (slice view): dst.is_pinned()={slice_dst.is_pinned()}") + +# Method 2: Use contiguous destination +contiguous_dst = tensor[:, 2].contiguous() +print(f" Method 2 (contiguous): dst.is_pinned()={contiguous_dst.is_pinned()}") + +print("\n" + "=" * 60) +print("Conclusion:") +print("=" * 60) + +if not slice_view.is_pinned(): + print("❌ Slicing LOSES pinned memory property!") + print(" This causes Device-to-Pageable transfers (SLOW)") +else: + print("✓ Slicing maintains pinned memory property") + +if contiguous_slice.is_pinned(): + print("✓ .contiguous() maintains pinned memory property") +else: + print("❌ .contiguous() also loses pinned memory property") + +print("\n" + "=" * 60) diff --git a/tests/test_pinned_transfer.py b/tests/test_pinned_transfer.py new file mode 100644 index 0000000..937d423 --- /dev/null +++ b/tests/test_pinned_transfer.py @@ -0,0 +1,124 @@ +""" +Test D2H transfer performance with pinned vs non-contiguous memory. 
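+Intended to be run under nsys (the NVTX ranges below mark each copy) to compare
+the Device -> Pageable and Device -> Pinned transfer paths.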
+""" + +import torch +import time + +print("=" * 60) +print("Test: D2H Transfer Performance (for nsys profiling)") +print("=" * 60) + +# Setup +num_layers = 8 +num_blocks = 16 +block_size = 1024 +num_kv_heads = 8 +head_dim = 64 + +# Allocate CPU cache (pinned) +k_cache_cpu = torch.zeros( + num_layers, num_blocks, block_size, num_kv_heads, head_dim, + dtype=torch.float16, device="cpu", pin_memory=True +) + +# Allocate GPU cache +k_cache_gpu = torch.randn( + num_layers, 4, block_size, num_kv_heads, head_dim, + dtype=torch.float16, device="cuda" +) + +# Warmup +print("\nWarmup...") +for _ in range(10): + k_cache_cpu[:, 0].copy_(k_cache_gpu[:, 0], non_blocking=True) + torch.cuda.synchronize() + +print(f"\nTensor info:") +print(f" k_cache_cpu.is_pinned(): {k_cache_cpu.is_pinned()}") +print(f" k_cache_cpu.is_contiguous(): {k_cache_cpu.is_contiguous()}") +print(f" k_cache_cpu[:, 0].is_pinned(): {k_cache_cpu[:, 0].is_pinned()}") +print(f" k_cache_cpu[:, 0].is_contiguous(): {k_cache_cpu[:, 0].is_contiguous()}") + +# Test 1: Non-contiguous slice (current approach) +print(f"\n" + "=" * 60) +print("Test 1: Non-contiguous slice copy (current approach)") +print("=" * 60) + +NUM_ITERS = 50 # Reduced for profiling + +torch.cuda.nvtx.range_push("Test1_NonContiguous") +times = [] +for i in range(NUM_ITERS): + torch.cuda.nvtx.range_push(f"D2H_NonContig_{i}") + start = time.perf_counter() + k_cache_cpu[:, i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True) + torch.cuda.synchronize() + times.append(time.perf_counter() - start) + torch.cuda.nvtx.range_pop() +torch.cuda.nvtx.range_pop() + +avg_time = sum(times) / len(times) +print(f"Average time: {avg_time * 1000:.3f} ms") +print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s") + +# Test 2: Transpose to make dimension contiguous +print(f"\n" + "=" * 60) +print("Test 2: Transpose to contiguous dimension") +print("=" * 60) + +# Reshape to [num_blocks, num_layers, block_size, num_kv_heads, head_dim] +k_cache_cpu_transposed = torch.zeros( + num_blocks, num_layers, block_size, num_kv_heads, head_dim, + dtype=torch.float16, device="cpu", pin_memory=True +) + +print(f" k_cache_cpu_transposed[0].is_pinned(): {k_cache_cpu_transposed[0].is_pinned()}") +print(f" k_cache_cpu_transposed[0].is_contiguous(): {k_cache_cpu_transposed[0].is_contiguous()}") + +torch.cuda.nvtx.range_push("Test2_Contiguous") +times = [] +for i in range(NUM_ITERS): + torch.cuda.nvtx.range_push(f"D2H_Contig_{i}") + start = time.perf_counter() + k_cache_cpu_transposed[i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True) + torch.cuda.synchronize() + times.append(time.perf_counter() - start) + torch.cuda.nvtx.range_pop() +torch.cuda.nvtx.range_pop() + +avg_time = sum(times) / len(times) +print(f"Average time: {avg_time * 1000:.3f} ms") +print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s") + +# Test 3: Fully contiguous buffer +print(f"\n" + "=" * 60) +print("Test 3: Fully contiguous buffer") +print("=" * 60) + +k_cache_cpu_flat = torch.zeros( + num_layers * block_size * num_kv_heads * head_dim, + dtype=torch.float16, device="cpu", pin_memory=True +) + +print(f" k_cache_cpu_flat.is_pinned(): {k_cache_cpu_flat.is_pinned()}") +print(f" k_cache_cpu_flat.is_contiguous(): {k_cache_cpu_flat.is_contiguous()}") + +torch.cuda.nvtx.range_push("Test3_FlatContiguous") +times = [] +for i in range(NUM_ITERS): + torch.cuda.nvtx.range_push(f"D2H_Flat_{i}") + start = time.perf_counter() + k_cache_cpu_flat.copy_(k_cache_gpu[:, 0].flatten(), non_blocking=True) 
+ torch.cuda.synchronize() + times.append(time.perf_counter() - start) + torch.cuda.nvtx.range_pop() +torch.cuda.nvtx.range_pop() + +avg_time = sum(times) / len(times) +print(f"Average time: {avg_time * 1000:.3f} ms") +print(f"Bandwidth: {k_cache_cpu_flat.numel() * 2 / avg_time / 1e9:.2f} GB/s") + +print("\n" + "=" * 60) +print("test_pinned_transfer: PASSED") +print("=" * 60)