[WIP] Replace merge attention with Triton kernel.

Author: Zijie Tian
Date: 2025-12-25 01:07:05 +08:00
parent cf5e7df093
commit 16fcf8350b
5 changed files with 490 additions and 405 deletions


@@ -6,7 +6,7 @@ Author: Zijie Tian
 import torch
 import time
-from nanovllm.comm import memcpy_2d, memcpy_2d_async
+from nanovllm.comm import memcpy_2d
 
 # ============================================================
 # Configuration
@@ -34,64 +34,12 @@ class Config:
 # ============================================================
-# Test 1: Async Transfer
-# ============================================================
-def test_async_transfer():
-    """Test asynchronous transfer with CUDA stream."""
-    print("\n[Test 1] Async Transfer Test")
-    cfg = Config()
-
-    # Create test data
-    cpu_data = torch.randn(
-        cfg.num_layers,
-        cfg.num_blocks,
-        cfg.features_per_block,
-        dtype=cfg.dtype,
-        pin_memory=True
-    )
-    gpu_buffer = torch.empty(
-        cfg.num_layers,
-        cfg.features_per_block,
-        dtype=cfg.dtype,
-        device='cuda'
-    )
-
-    # Create CUDA stream
-    stream = torch.cuda.Stream()
-
-    test_block_id = 5
-    spitch = cfg.bytes_per_layer
-    dpitch = cfg.bytes_per_block
-    width = cfg.bytes_per_block
-    height = cfg.num_layers
-
-    # Async transfer
-    with torch.cuda.stream(stream):
-        src_view = cpu_data[:, test_block_id, :]
-        memcpy_2d_async(gpu_buffer, src_view, dpitch, spitch, width, height, "h2d", stream)
-
-    # Wait for completion
-    stream.synchronize()
-
-    # Verify
-    expected = cpu_data[:, test_block_id, :].cuda()
-    if torch.allclose(gpu_buffer, expected, rtol=1e-3, atol=1e-3):
-        print(" Result: PASSED ✓")
-        return True
-    else:
-        print(" Result: FAILED ✗")
-        return False
-
-# ============================================================
-# Test 2: Performance Benchmark
+# Performance Benchmark
 # ============================================================
 def benchmark_sgdma():
     """Benchmark cudaMemcpy2D vs standard PyTorch methods."""
-    print("\n[Test 2] Performance Benchmark")
+    print("\n=== Performance Benchmark ===")
 
     cfg = Config()
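
Note: the deleted test_async_transfer exercised the cudaMemcpy2D pitch semantics that benchmark_sgdma still measures. Gathering one block across all layers makes the pinned host tensor a strided source, so spitch (host row stride) exceeds width (bytes actually copied per row). A minimal plain-PyTorch sketch of the same geometry follows; the sizes are illustrative assumptions, not the repo's Config values:

import torch

# Illustrative sizes; the real Config values are not shown in this diff.
num_layers, num_blocks, features_per_block = 32, 256, 1024
dtype = torch.float16
elem = torch.empty(0, dtype=dtype).element_size()

# Pitch arithmetic from the deleted test, for a
# [num_layers, num_blocks, features_per_block] host tensor:
width = features_per_block * elem                # cfg.bytes_per_block
spitch = num_blocks * features_per_block * elem  # cfg.bytes_per_layer (host row stride)
dpitch = width                                   # destination rows are dense
height = num_layers                              # one row copied per layer

cpu_data = torch.randn(num_layers, num_blocks, features_per_block,
                       dtype=dtype, pin_memory=True)
gpu_buffer = torch.empty(num_layers, features_per_block,
                         dtype=dtype, device="cuda")

# Plain-PyTorch equivalent of the pitched copy; PyTorch resolves the
# strided source view internally.
test_block_id = 5
gpu_buffer.copy_(cpu_data[:, test_block_id, :], non_blocking=True)
torch.cuda.synchronize()
assert torch.equal(gpu_buffer.cpu(), cpu_data[:, test_block_id, :])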
@@ -212,19 +160,17 @@ def benchmark_sgdma():
 # ============================================================
 if __name__ == "__main__":
-    print("=== CUDA sgDMA (cudaMemcpy2D) Tests ===")
+    print("=== CUDA sgDMA (cudaMemcpy2D) Benchmark ===")
 
     # Check CUDA availability
     if not torch.cuda.is_available():
-        print("CUDA not available. Skipping tests.")
+        print("CUDA not available. Skipping benchmark.")
         exit(1)
 
     # Print GPU info
     print(f"Using GPU: {torch.cuda.get_device_name()}")
 
-    # Run tests
-    test1_passed = test_async_transfer()
+    # Run benchmark
     benchmark_sgdma()
 
-    print("\n=== Tests Complete ===")
-    print(f"All tests {'PASSED ✓' if test1_passed else 'FAILED ✗'}")
+    print("\n=== Benchmark Complete ===")