[WIP] NEED to modify communication.

Zijie Tian
2025-12-24 21:57:51 +08:00
parent 782437c486
commit 6ec1b23982
9 changed files with 462 additions and 2 deletions


@@ -0,0 +1,40 @@
# Documentation Policy
## Do Not Create Unnecessary Documentation
**IMPORTANT**: Do NOT create extra markdown documentation files unless explicitly requested by the user.
### What NOT to do:
- ❌ Do NOT create README files proactively
- ❌ Do NOT create analysis documents (*.md) after completing tasks
- ❌ Do NOT create tutorial/guide documents
- ❌ Do NOT create summary documents
### What TO do:
- ✅ Only create documentation when user explicitly asks for it
- ✅ Provide information directly in conversation instead
- ✅ Update existing documentation if changes require it
- ✅ Add inline code comments where necessary
### Exceptions:
Documentation is acceptable ONLY when:
1. User explicitly requests "create a README" or "write documentation"
2. Updating existing documentation to reflect code changes
3. Adding inline comments/docstrings to code itself
### Examples:
**Bad** (Don't do this):
```
User: "Profile the code"
Assistant: [Creates profiling_results.md after profiling]
```
**Good** (Do this instead):
```
User: "Profile the code"
Assistant: [Runs profiling, shows results in conversation]
```

.gitignore vendored

@@ -192,4 +192,6 @@ cython_debug/
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
results/


@@ -296,3 +296,70 @@ Assertion `index out of bounds: 0 <= ... < 40960` failed
| CPU Offload (bench_offload.py) | ~7,200 | ~3.5 |
CPU offload trades performance for memory efficiency, enabling long-context inference on limited GPU memory.
## TODO: Performance Optimizations
### 1. Fix Non-Contiguous CPU Cache Layout (High Priority)
**Problem**: Device-to-Pageable transfers cause a 16x slowdown in CPU offload.
**Root Cause**:
The current CPU cache layout `[num_layers, num_cpu_blocks, ...]` produces non-contiguous memory access when slicing `k_cache_cpu[:, cpu_block_id]`. Although the tensor is pinned, the CUDA runtime falls back to the slow pageable transfer path because the slice is non-contiguous.
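For illustration, a minimal standalone check of this behavior (a sketch using shapes that mirror the cache layout above, not the engine's actual code):
```python
import torch

# Current layout: [num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim]
k_cache_cpu = torch.zeros(4, 8, 256, 8, 64, dtype=torch.float16, pin_memory=True)

block = k_cache_cpu[:, 0]      # per-block slice across all layers
print(block.is_pinned())       # True  - the view still lives in pinned memory
print(block.is_contiguous())   # False - layers are strided apart in memory, so copies
                               #         fall back to the slow pageable path described above
```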
**Evidence from Profiling** (`tests/test_pinned_transfer.py` + nsys):
```
Non-contiguous slice (current):
- Transfer type: Device -> Pageable
- Avg duration: 5.825 ms
- Bandwidth: 1.44 GB/s
Contiguous layout (optimized):
- Transfer type: Device -> Pinned
- Avg duration: 0.364 ms
- Bandwidth: 23.11 GB/s
Performance gain: 16x faster!
```
**Technical Details**:
- Pinned memory requires both `pin_memory=True` AND contiguous layout for fast DMA
- Non-contiguous slice forces CUDA to:
1. Allocate temporary pageable buffer on CPU
2. Copy non-contiguous data to buffer (CPU overhead)
3. Transfer from pageable buffer to GPU (slow path)
- PCIe DMA engine requires contiguous memory blocks for optimal throughput
**Solution**:
Change CPU cache tensor layout from:
```python
# Current (non-contiguous when accessing per-block):
k_cache_cpu = torch.zeros(
num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cpu", pin_memory=True
)
# Access: k_cache_cpu[:, cpu_block_id] -> non-contiguous!
# Optimized (contiguous per-block access):
k_cache_cpu = torch.zeros(
num_cpu_blocks, num_layers, block_size, num_kv_heads, head_dim,
dtype=dtype, device="cpu", pin_memory=True
)
# Access: k_cache_cpu[cpu_block_id] -> contiguous!
```
**Files to modify**:
- `nanovllm/kvcache/offload_engine.py`:
- Lines 104-111: Change tensor allocation layout
- All methods accessing CPU cache: update indexing
- `load_to_slot_layer()`, `offload_slot_to_cpu()`, `offload_slot_layer_to_cpu()`
- Update any other code that accesses `k_cache_cpu`/`v_cache_cpu` (the indexing change is sketched below)
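A rough, hypothetical sketch of the index swap (plain tensors standing in for the engine's caches; shapes mirror `tests/test_pinned_transfer.py`, and the GPU slot count is illustrative):
```python
import torch

num_layers, num_cpu_blocks, num_gpu_slots = 8, 16, 4
block_size, num_kv_heads, head_dim = 1024, 8, 64

# Proposed CPU layout: block dimension first, so one block (all layers) is contiguous.
k_cache_cpu = torch.zeros(num_cpu_blocks, num_layers, block_size, num_kv_heads, head_dim,
                          dtype=torch.float16, pin_memory=True)
k_cache_gpu = torch.zeros(num_layers, num_gpu_slots, block_size, num_kv_heads, head_dim,
                          dtype=torch.float16, device="cuda")

cpu_block_id, slot_idx, layer_id = 0, 0, 0

# D2H, whole block (was: k_cache_cpu[:, cpu_block_id].copy_(k_cache_gpu[:, slot_idx], ...)):
k_cache_cpu[cpu_block_id].copy_(k_cache_gpu[:, slot_idx], non_blocking=True)

# H2D, single layer (was: ...copy_(k_cache_cpu[layer_id, cpu_block_id], ...)):
k_cache_gpu[layer_id, slot_idx].copy_(k_cache_cpu[cpu_block_id, layer_id], non_blocking=True)

torch.cuda.synchronize()
```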
**Expected Impact**:
- 16x faster D2H transfers in CPU offload mode
- Overall prefill throughput improvement: ~2-3x (D2H is currently the bottleneck)
- No change to API or functionality; this is a pure performance optimization
**Reference**:
- Test: `tests/test_pinned_transfer.py`
- Profiling: `results/nsys/pinned_transfer_20251224_213158.nsys-rep`
- Analysis: See traces showing Device->Pageable vs Device->Pinned


@@ -8,6 +8,7 @@ Key design principles for CUDA Graph compatibility:
""" """
import torch import torch
import torch.cuda.nvtx
from torch import Tensor from torch import Tensor
from typing import Dict, List, Tuple, Optional from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass from dataclasses import dataclass
@@ -660,6 +661,7 @@ class OffloadEngine:
""" """
logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]") logger.debug(f"Ring load: layer={layer_id}, CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
torch.cuda.nvtx.range_push(f"H2D: L{layer_id} CPU[{cpu_block_id}]->Slot[{slot_idx}]")
with torch.cuda.stream(self.transfer_stream_main): with torch.cuda.stream(self.transfer_stream_main):
# Wait for previous compute on this slot to complete before overwriting # Wait for previous compute on this slot to complete before overwriting
# This prevents data race: transfer must not start until attention finishes reading # This prevents data race: transfer must not start until attention finishes reading
@@ -672,6 +674,7 @@ class OffloadEngine:
self.v_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
self.ring_slot_ready[slot_idx][layer_id].record(self.transfer_stream_main)
torch.cuda.nvtx.range_pop()
def wait_slot_layer(self, slot_idx: int, layer_id: int) -> None:
"""
@@ -718,6 +721,7 @@ class OffloadEngine:
""" """
logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]") logger.debug(f"Ring offload: GPU slot[{slot_idx}] -> CPU[{cpu_block_id}]")
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main): with torch.cuda.stream(self.transfer_stream_main):
self.transfer_stream_main.wait_stream(self.compute_stream) self.transfer_stream_main.wait_stream(self.compute_stream)
self.k_cache_cpu[:, cpu_block_id].copy_( self.k_cache_cpu[:, cpu_block_id].copy_(
@@ -727,6 +731,7 @@ class OffloadEngine:
self.v_cache_gpu[:, slot_idx], non_blocking=True
)
self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
torch.cuda.nvtx.range_pop()
def wait_slot_offload(self, slot_idx: int) -> None:
"""Wait for slot offload to complete."""


@@ -1,5 +1,6 @@
import logging
import torch
import torch.cuda.nvtx
from torch import nn
import triton
import triton.language as tl
@@ -117,6 +118,9 @@ class Attention(nn.Module):
""" """
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
current_chunk_idx = context.current_chunk_idx
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
# q, k, v shape: [total_tokens, num_heads, head_dim] # q, k, v shape: [total_tokens, num_heads, head_dim]
# Reshape for flash attention: [batch, seq, heads, dim] # Reshape for flash attention: [batch, seq, heads, dim]
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim] q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
@@ -128,7 +132,6 @@ class Attention(nn.Module):
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
current_chunk_idx = context.current_chunk_idx
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
# Get prefilled CPU blocks (blocks from previous chunks)
@@ -170,6 +173,7 @@ class Attention(nn.Module):
)
# Compute attention against current chunk's KV (with causal mask)
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
@@ -177,13 +181,17 @@ class Attention(nn.Module):
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
# Merge with accumulated
if o_acc is None:
final_o = current_o
else:
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop() # ChunkedPrefill
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
@@ -287,6 +295,8 @@ class Attention(nn.Module):
offload_engine.load_to_slot_layer(slot_A, self.layer_id, cpu_block_table[0])
for block_idx in range(num_blocks):
torch.cuda.nvtx.range_push(f"PipelineBlock: L{self.layer_id} B{block_idx}")
# Alternate between slot_A and slot_B
current_slot = slot_A if block_idx % 2 == 0 else slot_B
next_slot = slot_B if block_idx % 2 == 0 else slot_A
@@ -300,12 +310,14 @@ class Attention(nn.Module):
offload_engine.load_to_slot_layer(next_slot, self.layer_id, cpu_block_table[block_idx + 1])
# Compute attention on current slot's data
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot, self.layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
torch.cuda.nvtx.range_pop()
# Record compute done - this allows the next round to safely load into this slot
offload_engine.record_slot_compute_done(current_slot, self.layer_id)
@@ -316,6 +328,8 @@ class Attention(nn.Module):
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
torch.cuda.nvtx.range_pop() # PipelineBlock
return o_acc, lse_acc
def _chunked_decode_attention(

scripts/export_traces.sh Executable file

@@ -0,0 +1,71 @@
#!/bin/bash
# Export detailed profiling traces from nsys report
#
# Usage:
# bash scripts/export_traces.sh <nsys_report_file>
#
# Example:
# bash scripts/export_traces.sh results/nsys/attention_offload_20251224_205806.nsys-rep
set -e
if [ $# -eq 0 ]; then
    echo "Usage: $0 <nsys_report_file>"
    echo "Example: $0 results/nsys/attention_offload_20251224_205806.nsys-rep"
    exit 1
fi
NSYS_REPORT="$1"
BASENAME=$(basename "$NSYS_REPORT" .nsys-rep)
OUTPUT_DIR="results/nsys/traces"
mkdir -p "$OUTPUT_DIR"
echo "============================================================"
echo "Exporting traces from: $NSYS_REPORT"
echo "Output directory: $OUTPUT_DIR"
echo "============================================================"
# Export NVTX Push/Pop trace (shows timeline and nesting)
echo ""
echo "[1/4] Exporting NVTX Push/Pop trace..."
nsys stats --report nvtx_pushpop_trace "$NSYS_REPORT" \
    > "$OUTPUT_DIR/${BASENAME}_nvtx_trace.txt" 2>&1
echo " -> $OUTPUT_DIR/${BASENAME}_nvtx_trace.txt"
# Export CUDA GPU trace (shows kernel execution timeline)
echo ""
echo "[2/4] Exporting CUDA GPU trace..."
nsys stats --report cuda_gpu_trace "$NSYS_REPORT" \
    > "$OUTPUT_DIR/${BASENAME}_cuda_gpu_trace.txt" 2>&1
echo " -> $OUTPUT_DIR/${BASENAME}_cuda_gpu_trace.txt"
# Export CUDA API trace (shows API calls)
echo ""
echo "[3/4] Exporting CUDA API trace..."
nsys stats --report cuda_api_trace "$NSYS_REPORT" \
    > "$OUTPUT_DIR/${BASENAME}_cuda_api_trace.txt" 2>&1
echo " -> $OUTPUT_DIR/${BASENAME}_cuda_api_trace.txt"
# Export NVTX kernel summary (shows which kernels ran within NVTX ranges)
echo ""
echo "[4/4] Exporting NVTX kernel summary..."
nsys stats --report nvtx_kern_sum "$NSYS_REPORT" \
    > "$OUTPUT_DIR/${BASENAME}_nvtx_kern_sum.txt" 2>&1
echo " -> $OUTPUT_DIR/${BASENAME}_nvtx_kern_sum.txt"
echo ""
echo "============================================================"
echo "Traces exported successfully!"
echo "============================================================"
echo ""
echo "Key files:"
echo " - nvtx_trace.txt: Timeline with NVTX markers (shows nesting and timing)"
echo " - cuda_gpu_trace.txt: GPU kernel execution timeline"
echo " - cuda_api_trace.txt: CUDA API call timeline"
echo " - nvtx_kern_sum.txt: Kernels grouped by NVTX ranges"
echo ""
echo "For visual analysis, open in Nsight Systems GUI:"
echo " nsight-sys $NSYS_REPORT"
echo "============================================================"

scripts/profile_offload.sh Executable file

@@ -0,0 +1,67 @@
#!/bin/bash
# Profile test_attention_offload.py using NVIDIA Nsight Systems
#
# Usage:
# bash scripts/profile_offload.sh
#
# Output:
# results/nsys/attention_offload_<timestamp>.nsys-rep
#
# View results:
# nsight-sys results/nsys/attention_offload_<timestamp>.nsys-rep
set -e
# Configuration
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
OUTPUT_DIR="$PROJECT_ROOT/results/nsys"
TEST_SCRIPT="$PROJECT_ROOT/tests/test_attention_offload.py"
# Create output directory if needed
mkdir -p "$OUTPUT_DIR"
# Generate timestamp for unique filename
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
OUTPUT_FILE="$OUTPUT_DIR/attention_offload_$TIMESTAMP"
echo "============================================================"
echo "NVIDIA Nsight Systems Profiling"
echo "============================================================"
echo "Test script: $TEST_SCRIPT"
echo "Output file: $OUTPUT_FILE.nsys-rep"
echo ""
# nsys profile options:
# --trace=cuda,nvtx,osrt,cudnn,cublas : Trace CUDA API, NVTX markers, OS runtime, cuDNN, cuBLAS
# --cuda-memory-usage=true : Track CUDA memory allocations
# --stats=true : Generate summary statistics
# --force-overwrite=true : Overwrite existing output file
# --output=<path> : Output file path (without .nsys-rep extension)
echo "Running nsys profile..."
echo ""
nsys profile \
    --trace=cuda,nvtx,osrt,cudnn,cublas \
    --cuda-memory-usage=true \
    --stats=true \
    --force-overwrite=true \
    --output="$OUTPUT_FILE" \
    python "$TEST_SCRIPT"
echo ""
echo "============================================================"
echo "Profiling completed successfully!"
echo "============================================================"
echo "Output file: $OUTPUT_FILE.nsys-rep"
echo ""
echo "To view results in GUI:"
echo " nsight-sys $OUTPUT_FILE.nsys-rep"
echo ""
echo "To export statistics:"
echo " nsys stats --report cuda_api_sum $OUTPUT_FILE.nsys-rep"
echo " nsys stats --report cuda_gpu_kern_sum $OUTPUT_FILE.nsys-rep"
echo " nsys stats --report cuda_gpu_mem_size_sum $OUTPUT_FILE.nsys-rep"
echo "============================================================"


@@ -0,0 +1,70 @@
"""
Test if slicing maintains pinned memory property.
"""
import torch
print("=" * 60)
print("Test: Pinned Memory Property with Slicing")
print("=" * 60)
# Create a pinned tensor with shape similar to k_cache_cpu
# [num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim]
tensor = torch.zeros(8, 16, 1024, 8, 64, dtype=torch.float16, device="cpu", pin_memory=True)
print(f"\n1. Original tensor:")
print(f" - Shape: {tensor.shape}")
print(f" - is_pinned(): {tensor.is_pinned()}")
print(f" - is_contiguous(): {tensor.is_contiguous()}")
# Test slicing operation (what we do in offload_slot_to_cpu)
slice_view = tensor[:, 0] # Same as k_cache_cpu[:, cpu_block_id]
print(f"\n2. Sliced tensor [:, 0]:")
print(f" - Shape: {slice_view.shape}")
print(f" - is_pinned(): {slice_view.is_pinned()}")
print(f" - is_contiguous(): {slice_view.is_contiguous()}")
# Test if contiguous() helps
contiguous_slice = tensor[:, 0].contiguous()
print(f"\n3. Contiguous slice [:, 0].contiguous():")
print(f" - Shape: {contiguous_slice.shape}")
print(f" - is_pinned(): {contiguous_slice.is_pinned()}")
print(f" - is_contiguous(): {contiguous_slice.is_contiguous()}")
# Test copy behavior
gpu_tensor = torch.zeros(8, 4, 1024, 8, 64, dtype=torch.float16, device="cuda")
gpu_slice = gpu_tensor[:, 0]
print(f"\n4. GPU tensor slice:")
print(f" - Shape: {gpu_slice.shape}")
print(f" - is_contiguous(): {gpu_slice.is_contiguous()}")
# Simulate the problematic copy operation
print(f"\n5. Testing copy operations:")
# Method 1: Direct slice copy (current approach - SLOW)
slice_dst = tensor[:, 1]
print(f" Method 1 (slice view): dst.is_pinned()={slice_dst.is_pinned()}")
# Method 2: Use contiguous destination
contiguous_dst = tensor[:, 2].contiguous()
print(f" Method 2 (contiguous): dst.is_pinned()={contiguous_dst.is_pinned()}")
print("\n" + "=" * 60)
print("Conclusion:")
print("=" * 60)
if not slice_view.is_pinned():
    print("❌ Slicing LOSES pinned memory property!")
    print(" This causes Device-to-Pageable transfers (SLOW)")
else:
    print("✓ Slicing maintains pinned memory property")
if contiguous_slice.is_pinned():
    print("✓ .contiguous() maintains pinned memory property")
else:
    print("❌ .contiguous() also loses pinned memory property")
print("\n" + "=" * 60)


@@ -0,0 +1,124 @@
"""
Test D2H transfer performance with pinned vs non-contiguous memory.
"""
import torch
import time
print("=" * 60)
print("Test: D2H Transfer Performance (for nsys profiling)")
print("=" * 60)
# Setup
num_layers = 8
num_blocks = 16
block_size = 1024
num_kv_heads = 8
head_dim = 64
# Allocate CPU cache (pinned)
k_cache_cpu = torch.zeros(
num_layers, num_blocks, block_size, num_kv_heads, head_dim,
dtype=torch.float16, device="cpu", pin_memory=True
)
# Allocate GPU cache
k_cache_gpu = torch.randn(
num_layers, 4, block_size, num_kv_heads, head_dim,
dtype=torch.float16, device="cuda"
)
# Warmup
print("\nWarmup...")
for _ in range(10):
    k_cache_cpu[:, 0].copy_(k_cache_gpu[:, 0], non_blocking=True)
torch.cuda.synchronize()
print(f"\nTensor info:")
print(f" k_cache_cpu.is_pinned(): {k_cache_cpu.is_pinned()}")
print(f" k_cache_cpu.is_contiguous(): {k_cache_cpu.is_contiguous()}")
print(f" k_cache_cpu[:, 0].is_pinned(): {k_cache_cpu[:, 0].is_pinned()}")
print(f" k_cache_cpu[:, 0].is_contiguous(): {k_cache_cpu[:, 0].is_contiguous()}")
# Test 1: Non-contiguous slice (current approach)
print(f"\n" + "=" * 60)
print("Test 1: Non-contiguous slice copy (current approach)")
print("=" * 60)
NUM_ITERS = 50 # Reduced for profiling
torch.cuda.nvtx.range_push("Test1_NonContiguous")
times = []
for i in range(NUM_ITERS):
    torch.cuda.nvtx.range_push(f"D2H_NonContig_{i}")
    start = time.perf_counter()
    k_cache_cpu[:, i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True)
    torch.cuda.synchronize()
    times.append(time.perf_counter() - start)
    torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()
avg_time = sum(times) / len(times)
print(f"Average time: {avg_time * 1000:.3f} ms")
print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s")
# Test 2: Transpose to make dimension contiguous
print(f"\n" + "=" * 60)
print("Test 2: Transpose to contiguous dimension")
print("=" * 60)
# Reshape to [num_blocks, num_layers, block_size, num_kv_heads, head_dim]
k_cache_cpu_transposed = torch.zeros(
num_blocks, num_layers, block_size, num_kv_heads, head_dim,
dtype=torch.float16, device="cpu", pin_memory=True
)
print(f" k_cache_cpu_transposed[0].is_pinned(): {k_cache_cpu_transposed[0].is_pinned()}")
print(f" k_cache_cpu_transposed[0].is_contiguous(): {k_cache_cpu_transposed[0].is_contiguous()}")
torch.cuda.nvtx.range_push("Test2_Contiguous")
times = []
for i in range(NUM_ITERS):
    torch.cuda.nvtx.range_push(f"D2H_Contig_{i}")
    start = time.perf_counter()
    k_cache_cpu_transposed[i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True)
    torch.cuda.synchronize()
    times.append(time.perf_counter() - start)
    torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()
avg_time = sum(times) / len(times)
print(f"Average time: {avg_time * 1000:.3f} ms")
print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s")
# Test 3: Fully contiguous buffer
print(f"\n" + "=" * 60)
print("Test 3: Fully contiguous buffer")
print("=" * 60)
k_cache_cpu_flat = torch.zeros(
num_layers * block_size * num_kv_heads * head_dim,
dtype=torch.float16, device="cpu", pin_memory=True
)
print(f" k_cache_cpu_flat.is_pinned(): {k_cache_cpu_flat.is_pinned()}")
print(f" k_cache_cpu_flat.is_contiguous(): {k_cache_cpu_flat.is_contiguous()}")
torch.cuda.nvtx.range_push("Test3_FlatContiguous")
times = []
for i in range(NUM_ITERS):
    torch.cuda.nvtx.range_push(f"D2H_Flat_{i}")
    start = time.perf_counter()
    k_cache_cpu_flat.copy_(k_cache_gpu[:, 0].flatten(), non_blocking=True)
    torch.cuda.synchronize()
    times.append(time.perf_counter() - start)
    torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()
avg_time = sum(times) / len(times)
print(f"Average time: {avg_time * 1000:.3f} ms")
print(f"Bandwidth: {k_cache_cpu_flat.numel() * 2 / avg_time / 1e9:.2f} GB/s")
print("\n" + "=" * 60)
print("test_pinned_transfer: PASSED")
print("=" * 60)