[WIP] Replace merge attention with a Triton kernel.
@@ -6,7 +6,7 @@ Author: Zijie Tian
 
 import torch
 import time
-from nanovllm.comm import memcpy_2d, memcpy_2d_async
+from nanovllm.comm import memcpy_2d
 
 # ============================================================
 # Configuration
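Note on the dropped import: the removed `memcpy_2d_async(dst, src, dpitch, spitch, width, height, kind, stream)` call (visible in the deleted test below) maps one-to-one onto the CUDA Runtime's cudaMemcpy2DAsync, which takes destination and source pointers, the two row pitches and the copy width in bytes, the row count, a cudaMemcpyKind, and a stream. A minimal sketch of such a wrapper via ctypes, assuming that argument order; this binding is hypothetical, not nanovllm's actual implementation:

    import ctypes
    import torch

    # Hypothetical stand-in for nanovllm.comm.memcpy_2d_async: a thin ctypes
    # binding over cudaMemcpy2DAsync from the CUDA Runtime API.
    _cudart = ctypes.CDLL("libcudart.so")
    _cudart.cudaMemcpy2DAsync.restype = ctypes.c_int
    _cudart.cudaMemcpy2DAsync.argtypes = [
        ctypes.c_void_p, ctypes.c_size_t,  # dst, dpitch (bytes)
        ctypes.c_void_p, ctypes.c_size_t,  # src, spitch (bytes)
        ctypes.c_size_t, ctypes.c_size_t,  # width (bytes), height (rows)
        ctypes.c_int,                      # cudaMemcpyKind
        ctypes.c_void_p,                   # cudaStream_t
    ]

    # cudaMemcpyKind values from the CUDA headers: H2D=1, D2H=2, D2D=3.
    _KINDS = {"h2d": 1, "d2h": 2, "d2d": 3}

    def memcpy_2d_async_sketch(dst, src, dpitch, spitch, width, height, kind, stream):
        err = _cudart.cudaMemcpy2DAsync(
            ctypes.c_void_p(dst.data_ptr()), dpitch,
            ctypes.c_void_p(src.data_ptr()), spitch,
            width, height, _KINDS[kind],
            ctypes.c_void_p(stream.cuda_stream),
        )
        if err != 0:
            raise RuntimeError(f"cudaMemcpy2DAsync failed with error {err}")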
@@ -34,64 +34,12 @@ class Config:
 
 
 # ============================================================
-# Test 1: Async Transfer
-# ============================================================
-
-def test_async_transfer():
-    """Test asynchronous transfer with CUDA stream."""
-    print("\n[Test 1] Async Transfer Test")
-
-    cfg = Config()
-
-    # Create test data
-    cpu_data = torch.randn(
-        cfg.num_layers,
-        cfg.num_blocks,
-        cfg.features_per_block,
-        dtype=cfg.dtype,
-        pin_memory=True
-    )
-    gpu_buffer = torch.empty(
-        cfg.num_layers,
-        cfg.features_per_block,
-        dtype=cfg.dtype,
-        device='cuda'
-    )
-
-    # Create CUDA stream
-    stream = torch.cuda.Stream()
-
-    test_block_id = 5
-    spitch = cfg.bytes_per_layer
-    dpitch = cfg.bytes_per_block
-    width = cfg.bytes_per_block
-    height = cfg.num_layers
-
-    # Async transfer
-    with torch.cuda.stream(stream):
-        src_view = cpu_data[:, test_block_id, :]
-        memcpy_2d_async(gpu_buffer, src_view, dpitch, spitch, width, height, "h2d", stream)
-
-    # Wait for completion
-    stream.synchronize()
-
-    # Verify
-    expected = cpu_data[:, test_block_id, :].cuda()
-    if torch.allclose(gpu_buffer, expected, rtol=1e-3, atol=1e-3):
-        print("  Result: PASSED ✓")
-        return True
-    else:
-        print("  Result: FAILED ✗")
-        return False
-
-
-# ============================================================
-# Test 2: Performance Benchmark
+# Performance Benchmark
 # ============================================================
 
 def benchmark_sgdma():
     """Benchmark cudaMemcpy2D vs standard PyTorch methods."""
-    print("\n[Test 2] Performance Benchmark")
+    print("\n=== Performance Benchmark ===")
 
     cfg = Config()
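The pitch arithmetic in the removed test is the crux of the sgDMA trick: the source view `cpu_data[:, test_block_id, :]` is strided, one row of `features_per_block` elements per layer, with consecutive rows `bytes_per_layer` apart, so a single 2D copy gathers the whole block without any staging buffer. A self-contained sketch of the same arithmetic, with placeholder sizes standing in for the file's actual Config values:

    import torch

    # Placeholder sizes; the real values live in this file's Config class.
    num_layers, num_blocks, features_per_block = 32, 128, 4096
    dtype = torch.float16

    cpu_data = torch.randn(num_layers, num_blocks, features_per_block,
                           dtype=dtype, pin_memory=True)
    gpu_buffer = torch.empty(num_layers, features_per_block,
                             dtype=dtype, device="cuda")

    itemsize = cpu_data.element_size()
    spitch = num_blocks * features_per_block * itemsize  # bytes between source rows (one row per layer)
    dpitch = features_per_block * itemsize               # destination rows are densely packed
    width  = features_per_block * itemsize               # payload bytes per row
    height = num_layers                                  # one row per layer

    # Reference result via plain PyTorch: copy_ handles the strided CPU view,
    # which is the baseline the cudaMemcpy2D path is benchmarked against.
    gpu_buffer.copy_(cpu_data[:, 5, :])
    assert torch.equal(gpu_buffer.cpu(), cpu_data[:, 5, :])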
@@ -212,19 +160,17 @@ def benchmark_sgdma():
 # ============================================================
 
 if __name__ == "__main__":
-    print("=== CUDA sgDMA (cudaMemcpy2D) Tests ===")
+    print("=== CUDA sgDMA (cudaMemcpy2D) Benchmark ===")
 
     # Check CUDA availability
     if not torch.cuda.is_available():
-        print("CUDA not available. Skipping tests.")
+        print("CUDA not available. Skipping benchmark.")
         exit(1)
 
     # Print GPU info
     print(f"Using GPU: {torch.cuda.get_device_name()}")
 
-    # Run tests
-    test1_passed = test_async_transfer()
+    # Run benchmark
+    benchmark_sgdma()
 
-    print("\n=== Tests Complete ===")
-    print(f"All tests {'PASSED ✓' if test1_passed else 'FAILED ✗'}")
+    print("\n=== Benchmark Complete ===")
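The benchmark body itself is elided between the hunks above. For copies issued with non_blocking=True on pinned memory, wall-clock timing undercounts unless the device is synchronized, so benchmarks of this kind typically bracket the copy with CUDA events. A minimal sketch of that pattern; `time_h2d` is a hypothetical helper, not part of this file:

    import torch

    def time_h2d(dst, src, iters=100):
        # CUDA events measure elapsed time on the device timeline, so the
        # asynchronous copies are fully accounted for without a Python stopwatch.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        for _ in range(iters):
            dst.copy_(src, non_blocking=True)
        end.record()
        torch.cuda.synchronize()
        return start.elapsed_time(end) / iters  # milliseconds per copy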