[WIP] Added sgDMA operator for scatter kvcache communication.

Zijie Tian
2025-12-24 23:48:52 +08:00
parent 6ec1b23982
commit cf5e7df093
9 changed files with 1061 additions and 1 deletions


@@ -0,0 +1,23 @@
cmake_minimum_required(VERSION 3.18)
project(sgdma_test CUDA CXX)
# Find the CUDA toolkit (the CUDA language is already enabled by project() above)
find_package(CUDAToolkit REQUIRED)
# Set C++ standard
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
# CUDA flags
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3 --use_fast_math")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
# Build test executable
add_executable(sgdma_test sgdma_test.cpp)
target_link_libraries(sgdma_test PRIVATE CUDA::cudart)
# Set output directory
set_target_properties(sgdma_test PROPERTIES
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin
)


@@ -0,0 +1,326 @@
#include <cuda_runtime.h>
#include <iostream>
#include <chrono>
#include <cstring>
#include <cstdlib>
#include <iomanip>
// CUDA error checking macro
#define CUDA_CHECK(call) do { \
cudaError_t err = call; \
if (err != cudaSuccess) { \
std::cerr << "CUDA Error in " << __FILE__ << " at line " << __LINE__ << ": " \
<< cudaGetErrorString(err) << std::endl; \
exit(EXIT_FAILURE); \
} \
} while (0)
// Configuration matching realistic nano-vllm parameters
struct Config {
int num_layers = 32;
int num_blocks = 10; // Reduced from 100 to avoid huge allocation
int block_size = 4096;
int num_kv_heads = 8;
int head_dim = 128;
int dtype_size = 2; // float16
// Derived parameters (use size_t to avoid overflow)
size_t features_per_block() const { return (size_t)block_size * num_kv_heads * head_dim; }
size_t bytes_per_block() const { return features_per_block() * dtype_size; }
int total_blocks_per_layer() const { return num_blocks; }
size_t bytes_per_layer() const { return (size_t)num_blocks * bytes_per_block(); }
size_t total_bytes() const { return (size_t)num_layers * bytes_per_layer(); }
};
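// With these defaults: features_per_block = 4096 * 8 * 128 = 4,194,304 elements,
// bytes_per_block = 8 MiB, bytes_per_layer = 10 blocks * 8 MiB = 80 MiB, and
// total_bytes = 32 layers * 80 MiB = 2.5 GiB of pinned host memory. Fetching one
// block across all layers therefore moves 32 * 8 MiB = 256 MiB per transfer.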
// Timer utility
class Timer {
std::chrono::high_resolution_clock::time_point start_time;
public:
void start() { start_time = std::chrono::high_resolution_clock::now(); }
double elapsed_ms() {
auto end = std::chrono::high_resolution_clock::now();
return std::chrono::duration<double, std::milli>(end - start_time).count();
}
};
// Initialize CPU memory with test pattern
void init_test_data(void* data, size_t bytes, int seed) {
uint16_t* ptr = static_cast<uint16_t*>(data);
size_t num_elements = bytes / sizeof(uint16_t);
for (size_t i = 0; i < num_elements; i++) {
ptr[i] = static_cast<uint16_t>((seed + i) % 65536);
}
}
// Verify data correctness
bool verify_data(const void* data1, const void* data2, size_t bytes) {
const uint16_t* p1 = static_cast<const uint16_t*>(data1);
const uint16_t* p2 = static_cast<const uint16_t*>(data2);
size_t num_elements = bytes / sizeof(uint16_t);
for (size_t i = 0; i < num_elements; i++) {
if (p1[i] != p2[i]) {
std::cerr << "Mismatch at element " << i << ": "
<< p1[i] << " != " << p2[i] << std::endl;
return false;
}
}
return true;
}
// ============================================================
// Test 1: Basic Functionality Test
// ============================================================
bool test_basic_functionality(const Config& cfg) {
std::cout << "\n[Test 1] Basic Functionality Test" << std::endl;
std::cout << " Testing cudaMemcpy2D correctness with strided layout" << std::endl;
// Allocate strided CPU memory (pinned)
// Layout: [num_layers, num_blocks, block_features]
size_t total_bytes = cfg.total_bytes();
std::cout << " Allocating " << total_bytes / 1024.0 / 1024.0 / 1024.0 << " GB pinned memory..." << std::endl;
void* cpu_strided = nullptr;
CUDA_CHECK(cudaMallocHost(&cpu_strided, total_bytes));
std::cout << " CPU strided memory allocated at: " << cpu_strided << std::endl;
// Allocate GPU memory for one block (all layers)
size_t gpu_block_bytes = cfg.num_layers * cfg.bytes_per_block();
void* gpu_data = nullptr;
CUDA_CHECK(cudaMalloc(&gpu_data, gpu_block_bytes));
// Allocate CPU verify buffer
void* cpu_verify = nullptr;
CUDA_CHECK(cudaMallocHost(&cpu_verify, gpu_block_bytes));
// Initialize strided CPU memory
init_test_data(cpu_strided, total_bytes, 12345);
// Test: Copy block_id=5 from CPU to GPU using cudaMemcpy2D
int test_block_id = 5;
size_t spitch = cfg.bytes_per_layer(); // Source pitch (stride between layers)
size_t dpitch = cfg.bytes_per_block(); // Destination pitch (contiguous)
size_t width = cfg.bytes_per_block(); // Width to copy per row
size_t height = cfg.num_layers; // Number of rows (layers)
// Debug: print parameters
std::cout << " cudaMemcpy2D parameters:" << std::endl;
std::cout << " spitch: " << spitch << " bytes" << std::endl;
std::cout << " dpitch: " << dpitch << " bytes" << std::endl;
std::cout << " width: " << width << " bytes" << std::endl;
std::cout << " height: " << height << " rows" << std::endl;
std::cout << " dpitch >= width: " << (dpitch >= width ? "yes" : "no") << std::endl;
std::cout << " spitch >= width: " << (spitch >= width ? "yes" : "no") << std::endl;
// Calculate source pointer (first layer, block_id)
uint8_t* src_ptr = static_cast<uint8_t*>(cpu_strided) + test_block_id * cfg.bytes_per_block();
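    // cudaMemcpy2D copies `height` rows of `width` bytes each: row r is read from
    // src_ptr + r * spitch (layer r's copy of this block in the strided host buffer)
    // and written to dst + r * dpitch, so a single call gathers the block into a
    // contiguous [num_layers, bytes_per_block] GPU buffer.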
// H2D transfer
CUDA_CHECK(cudaMemcpy2D(
gpu_data, // dst
dpitch, // dpitch
src_ptr, // src
spitch, // spitch
width, // width
height, // height
cudaMemcpyHostToDevice
));
// D2H transfer back
CUDA_CHECK(cudaMemcpy2D(
cpu_verify, // dst
dpitch, // dpitch
gpu_data, // src
dpitch, // spitch
width, // width
height, // height
cudaMemcpyDeviceToHost
));
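    // Here dpitch == width (the GPU buffer is contiguous), so the read-back is
    // effectively a plain linear D2H copy expressed through the same 2D API.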
// Verify correctness
bool passed = true;
for (int layer = 0; layer < cfg.num_layers; layer++) {
uint8_t* expected_ptr = static_cast<uint8_t*>(cpu_strided) +
layer * cfg.bytes_per_layer() +
test_block_id * cfg.bytes_per_block();
uint8_t* actual_ptr = static_cast<uint8_t*>(cpu_verify) +
layer * cfg.bytes_per_block();
if (!verify_data(expected_ptr, actual_ptr, cfg.bytes_per_block())) {
std::cerr << " Verification failed at layer " << layer << std::endl;
passed = false;
break;
}
}
// Cleanup
CUDA_CHECK(cudaFreeHost(cpu_strided));
CUDA_CHECK(cudaFreeHost(cpu_verify));
CUDA_CHECK(cudaFree(gpu_data));
std::cout << " Result: " << (passed ? "PASSED ✓" : "FAILED ✗") << std::endl;
return passed;
}
// ============================================================
// Test 2: Performance Benchmark
// ============================================================
void test_performance_benchmark(const Config& cfg) {
std::cout << "\n[Test 2] Performance Benchmark" << std::endl;
std::cout << " Configuration:" << std::endl;
std::cout << " num_layers: " << cfg.num_layers << std::endl;
std::cout << " num_blocks: " << cfg.num_blocks << std::endl;
std::cout << " block_size: " << cfg.block_size << std::endl;
std::cout << " num_kv_heads: " << cfg.num_kv_heads << std::endl;
std::cout << " head_dim: " << cfg.head_dim << std::endl;
std::cout << " dtype_size: " << cfg.dtype_size << " bytes" << std::endl;
std::cout << " bytes_per_block: " << cfg.bytes_per_block() / 1024.0 << " KB" << std::endl;
std::cout << " total transfer size: " << cfg.num_layers * cfg.bytes_per_block() / 1024.0 / 1024.0 << " MB" << std::endl;
const int num_iterations = 100;
const int warmup = 10;
int test_block_id = 5;
// Allocate memory
size_t total_bytes = cfg.total_bytes();
void* cpu_strided = nullptr;
CUDA_CHECK(cudaMallocHost(&cpu_strided, total_bytes));
void* cpu_contiguous = nullptr;
size_t gpu_block_bytes = cfg.num_layers * cfg.bytes_per_block();
CUDA_CHECK(cudaMallocHost(&cpu_contiguous, gpu_block_bytes));
void* gpu_data = nullptr;
CUDA_CHECK(cudaMalloc(&gpu_data, gpu_block_bytes));
init_test_data(cpu_strided, total_bytes, 12345);
init_test_data(cpu_contiguous, gpu_block_bytes, 12345);
Timer timer;
double elapsed;
double bandwidth;
// ========================================
// Method A: cudaMemcpy2D with strided layout
// ========================================
size_t spitch = cfg.bytes_per_layer();
size_t dpitch = cfg.bytes_per_block();
size_t width = cfg.bytes_per_block();
size_t height = cfg.num_layers;
uint8_t* src_ptr = static_cast<uint8_t*>(cpu_strided) + test_block_id * cfg.bytes_per_block();
// Warmup
for (int i = 0; i < warmup; i++) {
CUDA_CHECK(cudaMemcpy2D(gpu_data, dpitch, src_ptr, spitch, width, height, cudaMemcpyHostToDevice));
}
CUDA_CHECK(cudaDeviceSynchronize());
// Benchmark
timer.start();
for (int i = 0; i < num_iterations; i++) {
CUDA_CHECK(cudaMemcpy2D(gpu_data, dpitch, src_ptr, spitch, width, height, cudaMemcpyHostToDevice));
}
CUDA_CHECK(cudaDeviceSynchronize());
elapsed = timer.elapsed_ms();
bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0);
std::cout << "\n Method A (cudaMemcpy2D strided):" << std::endl;
std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl;
std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl;
double method_a_bw = bandwidth;
// ========================================
// Method B: cudaMemcpy with contiguous layout (baseline)
// ========================================
// Warmup
for (int i = 0; i < warmup; i++) {
CUDA_CHECK(cudaMemcpy(gpu_data, cpu_contiguous, gpu_block_bytes, cudaMemcpyHostToDevice));
}
CUDA_CHECK(cudaDeviceSynchronize());
// Benchmark
timer.start();
for (int i = 0; i < num_iterations; i++) {
CUDA_CHECK(cudaMemcpy(gpu_data, cpu_contiguous, gpu_block_bytes, cudaMemcpyHostToDevice));
}
CUDA_CHECK(cudaDeviceSynchronize());
elapsed = timer.elapsed_ms();
bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0);
std::cout << "\n Method B (cudaMemcpy contiguous):" << std::endl;
std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl;
std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl;
double method_b_bw = bandwidth;
// ========================================
// Method C: Layer-by-layer copy (simulate PyTorch non-contiguous)
// ========================================
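    // Issues num_layers separate cudaMemcpy calls (one bytes_per_block() copy per
    // layer), so it pays per-call API/driver overhead that the single cudaMemcpy2D
    // of Method A avoids.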
// Warmup
for (int i = 0; i < warmup; i++) {
for (int layer = 0; layer < cfg.num_layers; layer++) {
uint8_t* src_layer = static_cast<uint8_t*>(cpu_strided) +
layer * cfg.bytes_per_layer() +
test_block_id * cfg.bytes_per_block();
uint8_t* dst_layer = static_cast<uint8_t*>(gpu_data) + layer * cfg.bytes_per_block();
CUDA_CHECK(cudaMemcpy(dst_layer, src_layer, cfg.bytes_per_block(), cudaMemcpyHostToDevice));
}
}
CUDA_CHECK(cudaDeviceSynchronize());
// Benchmark
timer.start();
for (int i = 0; i < num_iterations; i++) {
for (int layer = 0; layer < cfg.num_layers; layer++) {
uint8_t* src_layer = static_cast<uint8_t*>(cpu_strided) +
layer * cfg.bytes_per_layer() +
test_block_id * cfg.bytes_per_block();
uint8_t* dst_layer = static_cast<uint8_t*>(gpu_data) + layer * cfg.bytes_per_block();
CUDA_CHECK(cudaMemcpy(dst_layer, src_layer, cfg.bytes_per_block(), cudaMemcpyHostToDevice));
}
}
CUDA_CHECK(cudaDeviceSynchronize());
elapsed = timer.elapsed_ms();
bandwidth = (gpu_block_bytes * num_iterations / 1e9) / (elapsed / 1000.0);
std::cout << "\n Method C (layer-by-layer copy):" << std::endl;
std::cout << " Avg time: " << std::fixed << std::setprecision(3) << elapsed / num_iterations << " ms" << std::endl;
std::cout << " Bandwidth: " << std::setprecision(2) << bandwidth << " GB/s" << std::endl;
double method_c_bw = bandwidth;
// Summary
std::cout << "\n ========================================" << std::endl;
std::cout << " Performance Summary:" << std::endl;
std::cout << " Method A vs Method B: " << std::setprecision(2) << (method_a_bw / method_b_bw * 100) << "%" << std::endl;
std::cout << " Method A vs Method C: " << std::setprecision(2) << (method_a_bw / method_c_bw) << "x speedup" << std::endl;
std::cout << " ========================================" << std::endl;
// Cleanup
CUDA_CHECK(cudaFreeHost(cpu_strided));
CUDA_CHECK(cudaFreeHost(cpu_contiguous));
CUDA_CHECK(cudaFree(gpu_data));
}
int main() {
std::cout << "=== cudaMemcpy2D Test ===" << std::endl;
// Print CUDA device info
int device;
CUDA_CHECK(cudaGetDevice(&device));
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, device));
std::cout << "Using GPU: " << prop.name << std::endl;
std::cout << "Memory Clock Rate: " << prop.memoryClockRate / 1000 << " MHz" << std::endl;
std::cout << "Memory Bus Width: " << prop.memoryBusWidth << " bits" << std::endl;
std::cout << "Peak Memory Bandwidth: " <<
2.0 * prop.memoryClockRate * (prop.memoryBusWidth / 8) / 1.0e6 << " GB/s" << std::endl;
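    // memoryClockRate is reported in kHz; the factor of 2 accounts for DDR, the
    // bus width is divided by 8 to get bytes per transfer, and /1e6 converts
    // kHz * bytes into GB/s.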
Config cfg;
// Run tests
bool test1_passed = test_basic_functionality(cfg);
test_performance_benchmark(cfg);
std::cout << "\n=== Test Complete ===" << std::endl;
std::cout << "All tests " << (test1_passed ? "PASSED ✓" : "FAILED ✗") << std::endl;
return test1_passed ? 0 : 1;
}

tests/test_sgdma.py Normal file

@@ -0,0 +1,230 @@
"""
Tests for CUDA sgDMA (cudaMemcpy2D) extension.
Author: Zijie Tian
"""
import torch
import time
from nanovllm.comm import memcpy_2d, memcpy_2d_async
# ============================================================
# Configuration
# ============================================================
class Config:
num_layers = 32
num_blocks = 10
block_size = 4096
num_kv_heads = 8
head_dim = 128
dtype = torch.float16
@property
def features_per_block(self):
return self.block_size * self.num_kv_heads * self.head_dim
@property
def bytes_per_block(self):
return self.features_per_block * self.dtype.itemsize
@property
def bytes_per_layer(self):
return self.num_blocks * self.bytes_per_block
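# With these defaults each block is 4096 * 8 * 128 * 2 B = 8 MiB, a layer holds
# 10 blocks (80 MiB), one block fetched across all 32 layers moves 256 MiB, and
# the full pinned host buffer allocated below is 2.5 GiB.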
# ============================================================
# Test 1: Async Transfer
# ============================================================
def test_async_transfer():
"""Test asynchronous transfer with CUDA stream."""
print("\n[Test 1] Async Transfer Test")
cfg = Config()
# Create test data
cpu_data = torch.randn(
cfg.num_layers,
cfg.num_blocks,
cfg.features_per_block,
dtype=cfg.dtype,
pin_memory=True
)
gpu_buffer = torch.empty(
cfg.num_layers,
cfg.features_per_block,
dtype=cfg.dtype,
device='cuda'
)
# Create CUDA stream
stream = torch.cuda.Stream()
test_block_id = 5
spitch = cfg.bytes_per_layer
dpitch = cfg.bytes_per_block
width = cfg.bytes_per_block
height = cfg.num_layers
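    # Pitch semantics mirror cudaMemcpy2D: for each of `height` rows (= layers),
    # `width` bytes are read at src + r * spitch and written at dst + r * dpitch.
    # (This assumes memcpy_2d_async forwards these arguments to cudaMemcpy2DAsync
    # on the given stream.)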
# Async transfer
with torch.cuda.stream(stream):
src_view = cpu_data[:, test_block_id, :]
memcpy_2d_async(gpu_buffer, src_view, dpitch, spitch, width, height, "h2d", stream)
# Wait for completion
stream.synchronize()
# Verify
expected = cpu_data[:, test_block_id, :].cuda()
if torch.allclose(gpu_buffer, expected, rtol=1e-3, atol=1e-3):
print(" Result: PASSED ✓")
return True
else:
print(" Result: FAILED ✗")
return False
# ============================================================
# Test 2: Performance Benchmark
# ============================================================
def benchmark_sgdma():
"""Benchmark cudaMemcpy2D vs standard PyTorch methods."""
print("\n[Test 2] Performance Benchmark")
cfg = Config()
print(f" Configuration:")
print(f" num_layers: {cfg.num_layers}")
print(f" num_blocks: {cfg.num_blocks}")
print(f" block_size: {cfg.block_size}")
print(f" dtype: {cfg.dtype}")
print(f" bytes_per_block: {cfg.bytes_per_block / 1024:.1f} KB")
print(f" total transfer size: {cfg.num_layers * cfg.bytes_per_block / 1024 / 1024:.1f} MB")
num_iterations = 10
warmup = 3
test_block_id = 5
# Allocate memory
cpu_strided = torch.randn(
cfg.num_layers,
cfg.num_blocks,
cfg.features_per_block,
dtype=cfg.dtype,
pin_memory=True
)
# ========================================
# Method A: cudaMemcpy2D with sgDMA
# ========================================
gpu_buffer_a = torch.empty(cfg.num_layers, cfg.features_per_block, dtype=cfg.dtype, device='cuda')
spitch = cfg.bytes_per_layer
dpitch = cfg.bytes_per_block
width = cfg.bytes_per_block
height = cfg.num_layers
src_view = cpu_strided[:, test_block_id, :]
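    # src_view is a non-contiguous view of the pinned buffer (row stride equals
    # bytes_per_layer), so memcpy_2d can pass its base pointer plus explicit
    # pitches straight to the 2D copy without a .contiguous() staging copy
    # (assuming the extension accepts strided views this way).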
# Warmup
for _ in range(warmup):
memcpy_2d(gpu_buffer_a, src_view, dpitch, spitch, width, height, "h2d")
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_iterations):
memcpy_2d(gpu_buffer_a, src_view, dpitch, spitch, width, height, "h2d")
torch.cuda.synchronize()
elapsed_a = time.perf_counter() - start
avg_time_a = elapsed_a / num_iterations * 1000 # ms
bandwidth_a = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_a
print(f"\n Method A (cudaMemcpy2D sgDMA):")
print(f" Avg time: {avg_time_a:.3f} ms")
print(f" Bandwidth: {bandwidth_a:.2f} GB/s")
# ========================================
# Method B: PyTorch .cuda() on strided view
# ========================================
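    # .cuda() on a non-contiguous CPU view typically materializes a contiguous
    # temporary (a host-side gather) before the H2D transfer, which is the
    # overhead Method A is meant to avoid.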
# Warmup
for _ in range(warmup):
_ = cpu_strided[:, test_block_id, :].cuda()
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_iterations):
_ = cpu_strided[:, test_block_id, :].cuda()
torch.cuda.synchronize()
elapsed_b = time.perf_counter() - start
avg_time_b = elapsed_b / num_iterations * 1000 # ms
bandwidth_b = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_b
print(f"\n Method B (PyTorch .cuda() on strided):")
print(f" Avg time: {avg_time_b:.3f} ms")
print(f" Bandwidth: {bandwidth_b:.2f} GB/s")
# ========================================
# Method C: PyTorch .cuda() on contiguous (pinned)
# ========================================
# Create contiguous version with pinned memory
cpu_contiguous = torch.empty(
cfg.num_layers,
cfg.features_per_block,
dtype=cfg.dtype,
pin_memory=True
)
cpu_contiguous.copy_(cpu_strided[:, test_block_id, :])
# Warmup
for _ in range(warmup):
_ = cpu_contiguous.cuda()
torch.cuda.synchronize()
# Benchmark
start = time.perf_counter()
for _ in range(num_iterations):
_ = cpu_contiguous.cuda()
torch.cuda.synchronize()
elapsed_c = time.perf_counter() - start
avg_time_c = elapsed_c / num_iterations * 1000 # ms
bandwidth_c = (cfg.num_layers * cfg.bytes_per_block * num_iterations / 1e9) / elapsed_c
print(f"\n Method C (PyTorch .cuda() on contiguous):")
print(f" Avg time: {avg_time_c:.3f} ms")
print(f" Bandwidth: {bandwidth_c:.2f} GB/s")
# Summary
print(f"\n ========================================")
print(f" Performance Summary:")
print(f" Method A vs Method B: {bandwidth_a / bandwidth_b:.2f}x speedup")
print(f" Method A vs Method C: {bandwidth_a / bandwidth_c * 100:.2f}%")
print(f" ========================================")
# ============================================================
# Main
# ============================================================
if __name__ == "__main__":
print("=== CUDA sgDMA (cudaMemcpy2D) Tests ===")
# Check CUDA availability
if not torch.cuda.is_available():
print("CUDA not available. Skipping tests.")
exit(1)
# Print GPU info
print(f"Using GPU: {torch.cuda.get_device_name()}")
# Run tests
test1_passed = test_async_transfer()
benchmark_sgdma()
print("\n=== Tests Complete ===")
print(f"All tests {'PASSED ✓' if test1_passed else 'FAILED ✗'}")