[WIP] replace merge attention with triton kernel.
@@ -14,6 +14,7 @@ from typing import Dict, List, Tuple, Optional
 from dataclasses import dataclass
 
 from nanovllm.kvcache.kernels import gathered_copy_kv
+from nanovllm.comm import memcpy_2d_async
 from nanovllm.utils.logger import get_logger
 
 logger = get_logger("offload_engine")
@@ -65,6 +66,16 @@ class OffloadEngine:
         self.kv_dim = num_kv_heads * head_dim
         self.block_numel = block_size * self.kv_dim
 
+        # ========== sgDMA pitch parameters for strided transfers ==========
+        self.dtype_size = dtype.itemsize
+        self.cpu_pitch = num_cpu_blocks * self.block_numel * self.dtype_size
+        self.gpu_pitch = num_gpu_blocks * self.block_numel * self.dtype_size
+        self.width = self.block_numel * self.dtype_size
+        self.height = num_layers
+
+        logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, gpu_pitch={self.gpu_pitch}, "
+                    f"width={self.width}, height={self.height}")
+
         # ========== Unified Ring Buffer configuration ==========
         # Constraint checks
         assert num_gpu_blocks >= 2, \
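The pitch arithmetic above assumes the caches are laid out as [num_layers, num_blocks, block_size, kv_dim] with layers as the slowest axis, so that selecting one block across all layers yields `height` rows of `width` bytes spaced `pitch` bytes apart. A minimal sketch (sizes are illustrative, not this repo's config) that checks the arithmetic against torch strides:

import torch

# Illustrative sizes only; the real values come from the model config.
num_layers, num_cpu_blocks, block_size, kv_dim = 4, 8, 16, 128
dtype = torch.float16

# Host-side cache: one row per layer, num_cpu_blocks blocks per row.
k_cache_cpu = torch.zeros(num_layers, num_cpu_blocks, block_size, kv_dim,
                          dtype=dtype, pin_memory=True)

block_numel = block_size * kv_dim
width = block_numel * dtype.itemsize   # one block within one layer, in bytes
cpu_pitch = num_cpu_blocks * width     # byte distance between layer rows

# k_cache_cpu[:, b] is the same block in every layer: num_layers rows of
# `width` bytes whose starts are `cpu_pitch` bytes apart -- exactly the
# (pitch, width, height) geometry a 2D memcpy expects.
assert k_cache_cpu[:, 0].stride(0) * dtype.itemsize == cpu_pitch
assert k_cache_cpu[:, 0][0].numel() * dtype.itemsize == width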
@@ -478,14 +489,18 @@ class OffloadEngine:
 
         with torch.cuda.stream(stream):
             for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
-                # Copy all layers at once
-                self.k_cache_gpu[:, gpu_slot].copy_(
-                    self.k_cache_cpu[:, cpu_block_id],
-                    non_blocking=True
+                # Copy all layers at once using sgDMA
+                memcpy_2d_async(
+                    self.k_cache_gpu[:, gpu_slot],
+                    self.k_cache_cpu[:, cpu_block_id],
+                    self.gpu_pitch, self.cpu_pitch, self.width, self.height,
+                    "h2d", stream=stream
                 )
-                self.v_cache_gpu[:, gpu_slot].copy_(
-                    self.v_cache_cpu[:, cpu_block_id],
-                    non_blocking=True
+                memcpy_2d_async(
+                    self.v_cache_gpu[:, gpu_slot],
+                    self.v_cache_cpu[:, cpu_block_id],
+                    self.gpu_pitch, self.cpu_pitch, self.width, self.height,
+                    "h2d", stream=stream
                 )
 
         stream.synchronize()
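To make the call signature concrete: each of the `height` rows in the 2D copy is one layer's copy of the block, `width` bytes long, and consecutive rows sit `pitch` bytes apart in each buffer (gpu_pitch on the device side, cpu_pitch on the host side). A hedged pure-PyTorch reference for what one memcpy_2d_async call above is expected to do (reference semantics only, not the nanovllm.comm implementation):

def memcpy_2d_reference(dst, src, height, non_blocking=True):
    # dst, src: [num_layers, block_size, kv_dim] views such as
    # k_cache_gpu[:, gpu_slot] / k_cache_cpu[:, cpu_block_id]; the layer
    # stride of each view encodes its pitch.
    for layer in range(height):  # one 2D "row" per layer
        dst[layer].copy_(src[layer], non_blocking=non_blocking)

The previous copy_ on the strided view moved the same bytes; the sgDMA call expresses them as a single 2D transfer, which is exactly what the pitch parameters describe.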
@@ -697,11 +712,17 @@ class OffloadEngine:
             logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")
 
         with torch.cuda.stream(self.transfer_stream_main):
-            self.k_cache_gpu[:, slot_idx].copy_(
-                self.k_cache_cpu[:, cpu_block_id], non_blocking=True
+            memcpy_2d_async(
+                self.k_cache_gpu[:, slot_idx],
+                self.k_cache_cpu[:, cpu_block_id],
+                self.gpu_pitch, self.cpu_pitch, self.width, self.height,
+                "h2d", stream=self.transfer_stream_main
             )
-            self.v_cache_gpu[:, slot_idx].copy_(
-                self.v_cache_cpu[:, cpu_block_id], non_blocking=True
+            memcpy_2d_async(
+                self.v_cache_gpu[:, slot_idx],
+                self.v_cache_cpu[:, cpu_block_id],
+                self.gpu_pitch, self.cpu_pitch, self.width, self.height,
+                "h2d", stream=self.transfer_stream_main
             )
             self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main)
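On the consumer side, compute must not read the ring slot until this event has fired. The consumer code is not part of this diff; a sketch of the assumed wait pattern, using the fields named above:

# Hypothetical consumer, assuming attention runs on self.compute_stream:
# block the compute stream (not the host) until the H2D copy for this
# ring slot has completed on transfer_stream_main.
self.compute_stream.wait_event(self.ring_slot_all_layers_ready[slot_idx])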
@@ -724,11 +745,17 @@ class OffloadEngine:
         torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
         with torch.cuda.stream(self.transfer_stream_main):
             self.transfer_stream_main.wait_stream(self.compute_stream)
-            self.k_cache_cpu[:, cpu_block_id].copy_(
-                self.k_cache_gpu[:, slot_idx], non_blocking=True
+            memcpy_2d_async(
+                self.k_cache_cpu[:, cpu_block_id],
+                self.k_cache_gpu[:, slot_idx],
+                self.cpu_pitch, self.gpu_pitch, self.width, self.height,
+                "d2h", stream=self.transfer_stream_main
             )
-            self.v_cache_cpu[:, cpu_block_id].copy_(
-                self.v_cache_gpu[:, slot_idx], non_blocking=True
+            memcpy_2d_async(
+                self.v_cache_cpu[:, cpu_block_id],
+                self.v_cache_gpu[:, slot_idx],
+                self.cpu_pitch, self.gpu_pitch, self.width, self.height,
+                "d2h", stream=self.transfer_stream_main
            )
             self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
         torch.cuda.nvtx.range_pop()
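Symmetrically, nothing may hand out the CPU block (or reuse the GPU slot) until this offload event fires. A hypothetical guard, assuming the next reader is host code:

# Hypothetical guard before touching the offloaded block on the host.
# Event.synchronize() blocks the host; if the next reader were a kernel,
# compute_stream.wait_event(...) would be the non-blocking choice.
self.ring_slot_all_layers_offload_done[slot_idx].synchronize()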
@@ -813,11 +840,17 @@ class OffloadEngine:
 
         with torch.cuda.stream(self.transfer_stream_main):
             self.transfer_stream_main.wait_stream(self.compute_stream)
-            self.k_cache_cpu[:, cpu_block_id].copy_(
-                self.k_cache_gpu[:, self.decode_slot], non_blocking=True
+            memcpy_2d_async(
+                self.k_cache_cpu[:, cpu_block_id],
+                self.k_cache_gpu[:, self.decode_slot],
+                self.cpu_pitch, self.gpu_pitch, self.width, self.height,
+                "d2h", stream=self.transfer_stream_main
             )
-            self.v_cache_cpu[:, cpu_block_id].copy_(
-                self.v_cache_gpu[:, self.decode_slot], non_blocking=True
+            memcpy_2d_async(
+                self.v_cache_cpu[:, cpu_block_id],
+                self.v_cache_gpu[:, self.decode_slot],
+                self.cpu_pitch, self.gpu_pitch, self.width, self.height,
+                "d2h", stream=self.transfer_stream_main
             )
             self.decode_offload_done.record(self.transfer_stream_main)
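For a sense of scale per call (illustrative numbers, not this repo's actual config):

# Illustrative sizing: 32-layer model, fp16, 8 KV heads of head_dim 128,
# block_size 16.
block_size, num_kv_heads, head_dim, num_layers = 16, 8, 128, 32
dtype_size = 2  # fp16 bytes

width = block_size * num_kv_heads * head_dim * dtype_size  # 32768 B = 32 KiB
bytes_per_call = width * num_layers                        # height = num_layers
print(f"{width} B per layer row, {bytes_per_call / 2**20:.1f} MiB per K or V block")

At these sizes each call moves 1 MiB in one strided DMA, and each block movement issues two such calls (K and V), about 2 MiB per block.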