[WIP] Need to modify communication.

Zijie Tian
2025-12-24 21:57:51 +08:00
parent 782437c486
commit 6ec1b23982
9 changed files with 462 additions and 2 deletions


@@ -0,0 +1,70 @@
"""
Test if slicing maintains pinned memory property.
"""
import torch
print("=" * 60)
print("Test: Pinned Memory Property with Slicing")
print("=" * 60)
# Create a pinned tensor with shape similar to k_cache_cpu
# [num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim]
tensor = torch.zeros(8, 16, 1024, 8, 64, dtype=torch.float16, device="cpu", pin_memory=True)
print(f"\n1. Original tensor:")
print(f" - Shape: {tensor.shape}")
print(f" - is_pinned(): {tensor.is_pinned()}")
print(f" - is_contiguous(): {tensor.is_contiguous()}")
# Test slicing operation (what we do in offload_slot_to_cpu)
slice_view = tensor[:, 0] # Same as k_cache_cpu[:, cpu_block_id]
print(f"\n2. Sliced tensor [:, 0]:")
print(f" - Shape: {slice_view.shape}")
print(f" - is_pinned(): {slice_view.is_pinned()}")
print(f" - is_contiguous(): {slice_view.is_contiguous()}")
# Test if contiguous() helps
contiguous_slice = tensor[:, 0].contiguous()
print(f"\n3. Contiguous slice [:, 0].contiguous():")
print(f" - Shape: {contiguous_slice.shape}")
print(f" - is_pinned(): {contiguous_slice.is_pinned()}")
print(f" - is_contiguous(): {contiguous_slice.is_contiguous()}")
# Test copy behavior
gpu_tensor = torch.zeros(8, 4, 1024, 8, 64, dtype=torch.float16, device="cuda")
gpu_slice = gpu_tensor[:, 0]
print(f"\n4. GPU tensor slice:")
print(f" - Shape: {gpu_slice.shape}")
print(f" - is_contiguous(): {gpu_slice.is_contiguous()}")
# Simulate the problematic copy operation
print(f"\n5. Testing copy operations:")
# Method 1: Direct slice copy (current approach - SLOW)
slice_dst = tensor[:, 1]
print(f" Method 1 (slice view): dst.is_pinned()={slice_dst.is_pinned()}")
# Method 2: Use contiguous destination
contiguous_dst = tensor[:, 2].contiguous()
print(f" Method 2 (contiguous): dst.is_pinned()={contiguous_dst.is_pinned()}")
print("\n" + "=" * 60)
print("Conclusion:")
print("=" * 60)
if not slice_view.is_pinned():
    print("❌ Slicing LOSES pinned memory property!")
    print(" This causes Device-to-Pageable transfers (SLOW)")
else:
    print("✓ Slicing maintains pinned memory property")
if contiguous_slice.is_pinned():
    print("✓ .contiguous() maintains pinned memory property")
else:
    print("❌ .contiguous() loses pinned memory property")
print("\n" + "=" * 60)


@@ -0,0 +1,124 @@
"""
Test D2H transfer performance with pinned vs non-contiguous memory.
"""
import torch
import time
print("=" * 60)
print("Test: D2H Transfer Performance (for nsys profiling)")
print("=" * 60)
# Setup
num_layers = 8
num_blocks = 16
block_size = 1024
num_kv_heads = 8
head_dim = 64
# Allocate CPU cache (pinned)
k_cache_cpu = torch.zeros(
    num_layers, num_blocks, block_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cpu", pin_memory=True
)
# Allocate GPU cache
k_cache_gpu = torch.randn(
    num_layers, 4, block_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cuda"
)
# Warmup
print("\nWarmup...")
for _ in range(10):
k_cache_cpu[:, 0].copy_(k_cache_gpu[:, 0], non_blocking=True)
torch.cuda.synchronize()
print(f"\nTensor info:")
print(f" k_cache_cpu.is_pinned(): {k_cache_cpu.is_pinned()}")
print(f" k_cache_cpu.is_contiguous(): {k_cache_cpu.is_contiguous()}")
print(f" k_cache_cpu[:, 0].is_pinned(): {k_cache_cpu[:, 0].is_pinned()}")
print(f" k_cache_cpu[:, 0].is_contiguous(): {k_cache_cpu[:, 0].is_contiguous()}")
# Test 1: Non-contiguous slice (current approach)
print("\n" + "=" * 60)
print("Test 1: Non-contiguous slice copy (current approach)")
print("=" * 60)
NUM_ITERS = 50  # reduced for profiling
torch.cuda.nvtx.range_push("Test1_NonContiguous")
times = []
for i in range(NUM_ITERS):
    torch.cuda.nvtx.range_push(f"D2H_NonContig_{i}")
    start = time.perf_counter()
    k_cache_cpu[:, i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True)
    torch.cuda.synchronize()
    times.append(time.perf_counter() - start)
    torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()
avg_time = sum(times) / len(times)
print(f"Average time: {avg_time * 1000:.3f} ms")
print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s")
# Test 2: Block-major layout so each block slice is contiguous
print("\n" + "=" * 60)
print("Test 2: Transpose to contiguous dimension")
print("=" * 60)
# Allocate a fresh pinned buffer laid out as
# [num_blocks, num_layers, block_size, num_kv_heads, head_dim],
# so indexing a block gives a contiguous (and still pinned) view.
k_cache_cpu_transposed = torch.zeros(
    num_blocks, num_layers, block_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cpu", pin_memory=True
)
print(f" k_cache_cpu_transposed[0].is_pinned(): {k_cache_cpu_transposed[0].is_pinned()}")
print(f" k_cache_cpu_transposed[0].is_contiguous(): {k_cache_cpu_transposed[0].is_contiguous()}")
torch.cuda.nvtx.range_push("Test2_Contiguous")
times = []
for i in range(NUM_ITERS):
    torch.cuda.nvtx.range_push(f"D2H_Contig_{i}")
    start = time.perf_counter()
    k_cache_cpu_transposed[i % num_blocks].copy_(k_cache_gpu[:, 0], non_blocking=True)
    torch.cuda.synchronize()
    times.append(time.perf_counter() - start)
    torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()
avg_time = sum(times) / len(times)
print(f"Average time: {avg_time * 1000:.3f} ms")
print(f"Bandwidth: {k_cache_gpu[:, 0].numel() * 2 / avg_time / 1e9:.2f} GB/s")
# Test 3: Fully contiguous buffer
print("\n" + "=" * 60)
print("Test 3: Fully contiguous buffer")
print("=" * 60)
# One flat pinned buffer sized for a single block slice.
k_cache_cpu_flat = torch.zeros(
    num_layers * block_size * num_kv_heads * head_dim,
    dtype=torch.float16, device="cpu", pin_memory=True
)
print(f" k_cache_cpu_flat.is_pinned(): {k_cache_cpu_flat.is_pinned()}")
print(f" k_cache_cpu_flat.is_contiguous(): {k_cache_cpu_flat.is_contiguous()}")
torch.cuda.nvtx.range_push("Test3_FlatContiguous")
times = []
for i in range(NUM_ITERS):
    torch.cuda.nvtx.range_push(f"D2H_Flat_{i}")
    start = time.perf_counter()
    # .flatten() on the strided GPU slice first gathers it into a contiguous
    # device-side copy; the D2H transfer itself is then a single memcpy.
    k_cache_cpu_flat.copy_(k_cache_gpu[:, 0].flatten(), non_blocking=True)
    torch.cuda.synchronize()
    times.append(time.perf_counter() - start)
    torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop()
avg_time = sum(times) / len(times)
print(f"Average time: {avg_time * 1000:.3f} ms")
print(f"Bandwidth: {k_cache_cpu_flat.numel() * 2 / avg_time / 1e9:.2f} GB/s")
print("\n" + "=" * 60)
print("test_pinned_transfer: PASSED")
print("=" * 60)