"""
|
|
Triton kernels for CPU-GPU KV cache transfer.
|
|
|
|
These kernels are designed to be CUDA Graph compatible:
|
|
- All tensor addresses are fixed at graph capture time
|
|
- Only the content of index tensors changes between replays
|
|
"""
|
|
|
|
import torch
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
@triton.jit
|
|
def gathered_copy_kernel(
|
|
src_ptr, # Source tensor base pointer (CPU pinned or GPU)
|
|
dst_ptr, # Destination tensor base pointer (GPU)
|
|
indices_ptr, # Gather indices [num_dst_blocks]
|
|
num_dst_blocks, # Number of destination blocks
|
|
block_numel: tl.constexpr, # Elements per block (block_size * kv_heads * head_dim)
|
|
BLOCK_SIZE: tl.constexpr = 1024,
|
|
):
|
|
"""
|
|
Gathered copy kernel: dst[i] = src[indices[i]]
|
|
|
|
Each program instance handles one destination block.
|
|
The indices tensor specifies which source block to copy from.
|
|
|
|
This kernel is CUDA Graph compatible because:
|
|
- src_ptr, dst_ptr, indices_ptr addresses are fixed
|
|
- Only indices content changes between graph replays
|
|
|
|
Args:
|
|
src_ptr: Base pointer to source blocks [num_src_blocks, block_numel]
|
|
dst_ptr: Base pointer to destination blocks [num_dst_blocks, block_numel]
|
|
indices_ptr: Gather indices [num_dst_blocks], each value is a source block index
|
|
num_dst_blocks: Number of destination blocks to copy
|
|
block_numel: Number of elements per block
|
|
BLOCK_SIZE: Triton block size for parallelization
|
|
"""
|
|
dst_block_idx = tl.program_id(0)
|
|
|
|
# Skip if out of range
|
|
if dst_block_idx >= num_dst_blocks:
|
|
return
|
|
|
|
# Load source block index from indices tensor
|
|
src_block_idx = tl.load(indices_ptr + dst_block_idx)
|
|
|
|
# Skip if index is -1 (invalid/no-op marker)
|
|
if src_block_idx < 0:
|
|
return
|
|
|
|
# Calculate base offsets
|
|
src_base = src_block_idx * block_numel
|
|
dst_base = dst_block_idx * block_numel
|
|
|
|
# Copy block data in chunks of BLOCK_SIZE
|
|
for start in range(0, block_numel, BLOCK_SIZE):
|
|
offsets = start + tl.arange(0, BLOCK_SIZE)
|
|
mask = offsets < block_numel
|
|
|
|
# Load from source and store to destination
|
|
data = tl.load(src_ptr + src_base + offsets, mask=mask)
|
|
tl.store(dst_ptr + dst_base + offsets, data, mask=mask)
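

# Illustrative addition (not part of the original kernels): a minimal plain-PyTorch
# reference with the same gather semantics as gathered_copy_kernel. The helper name
# _reference_gathered_copy is hypothetical; it is only meant for sanity-checking
# the Triton path on small inputs.
def _reference_gathered_copy(
    src: torch.Tensor, dst: torch.Tensor, indices: torch.Tensor
) -> None:
    """Reference semantics: dst[i] = src[indices[i]]; indices[i] < 0 is a no-op."""
    for i, s in enumerate(indices.tolist()):
        if s >= 0:
            # Tensor.copy_ handles cross-device (e.g. pinned CPU -> GPU) transfers.
            dst[i].copy_(src[s])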


@triton.jit
def gathered_copy_kv_kernel(
    k_src_ptr,  # K cache source [num_src_blocks, block_size, kv_heads, head_dim]
    v_src_ptr,  # V cache source
    k_dst_ptr,  # K cache destination
    v_dst_ptr,  # V cache destination
    indices_ptr,  # Gather indices [num_dst_blocks]
    num_dst_blocks,  # Number of destination blocks
    block_numel: tl.constexpr,  # Elements per block
    BLOCK_SIZE: tl.constexpr = 1024,
):
    """
    Gathered copy for both K and V caches simultaneously.

    More efficient than calling gathered_copy_kernel twice because:
    - Single kernel launch overhead
    - Better memory access patterns when K and V are accessed together
    """
    dst_block_idx = tl.program_id(0)

    if dst_block_idx >= num_dst_blocks:
        return

    src_block_idx = tl.load(indices_ptr + dst_block_idx)

    if src_block_idx < 0:
        return

    src_base = src_block_idx * block_numel
    dst_base = dst_block_idx * block_numel

    for start in range(0, block_numel, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < block_numel

        # Copy K cache
        k_data = tl.load(k_src_ptr + src_base + offsets, mask=mask)
        tl.store(k_dst_ptr + dst_base + offsets, k_data, mask=mask)

        # Copy V cache
        v_data = tl.load(v_src_ptr + src_base + offsets, mask=mask)
        tl.store(v_dst_ptr + dst_base + offsets, v_data, mask=mask)


def gathered_copy(
    src: torch.Tensor,
    dst: torch.Tensor,
    indices: torch.Tensor,
) -> None:
    """
    Perform gathered copy: dst[i] = src[indices[i]]

    Args:
        src: Source tensor [num_src_blocks, ...]
        dst: Destination tensor [num_dst_blocks, ...]
        indices: Index tensor [num_dst_blocks], dtype=int64
            -1 means skip (no-op)

    Note:
        - src can be on CPU (pinned memory) or GPU
        - dst must be on GPU
        - indices must be on GPU
        - All shapes after first dimension must match
    """
    assert dst.is_cuda, "Destination must be on GPU"
    assert indices.is_cuda, "Indices must be on GPU"
    assert src.shape[1:] == dst.shape[1:], "Shape mismatch after first dimension"

    num_dst_blocks = dst.shape[0]
    block_numel = dst[0].numel()

    # Flatten for kernel
    src_flat = src.view(src.shape[0], -1)
    dst_flat = dst.view(dst.shape[0], -1)

    grid = (num_dst_blocks,)
    gathered_copy_kernel[grid](
        src_flat,
        dst_flat,
        indices,
        num_dst_blocks,
        block_numel=block_numel,
    )
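

# Illustrative usage sketch (not in the original module), assuming pinned host
# memory is directly addressable from the device, as the docstring above implies.
# All names and shapes below (num_cpu_blocks, num_gpu_blocks, etc.) are made up.
def _example_gathered_copy() -> None:
    num_cpu_blocks, num_gpu_blocks, block_numel = 64, 8, 16 * 8 * 128
    cpu_pool = torch.randn(num_cpu_blocks, block_numel, pin_memory=True)
    gpu_pool = torch.zeros(num_gpu_blocks, block_numel, device="cuda")

    # Copy CPU blocks 3, 7, 11 into GPU slots 0, 1, 2; leave the rest untouched.
    indices = torch.full((num_gpu_blocks,), -1, dtype=torch.int64, device="cuda")
    indices[:3] = torch.tensor([3, 7, 11], device="cuda")

    gathered_copy(cpu_pool, gpu_pool, indices)
    torch.cuda.synchronize()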


def gathered_copy_kv(
    k_src: torch.Tensor,
    v_src: torch.Tensor,
    k_dst: torch.Tensor,
    v_dst: torch.Tensor,
    indices: torch.Tensor,
) -> None:
    """
    Perform gathered copy for both K and V caches.

    Args:
        k_src, v_src: Source K/V caches [num_src_blocks, block_size, kv_heads, head_dim]
        k_dst, v_dst: Destination K/V caches [num_dst_blocks, block_size, kv_heads, head_dim]
        indices: Index tensor [num_dst_blocks], dtype=int64
    """
    assert k_dst.is_cuda and v_dst.is_cuda, "Destinations must be on GPU"
    assert indices.is_cuda, "Indices must be on GPU"
    assert k_src.shape[1:] == k_dst.shape[1:], "K shape mismatch"
    assert v_src.shape[1:] == v_dst.shape[1:], "V shape mismatch"

    num_dst_blocks = k_dst.shape[0]
    block_numel = k_dst[0].numel()

    k_src_flat = k_src.view(k_src.shape[0], -1)
    v_src_flat = v_src.view(v_src.shape[0], -1)
    k_dst_flat = k_dst.view(k_dst.shape[0], -1)
    v_dst_flat = v_dst.view(v_dst.shape[0], -1)

    grid = (num_dst_blocks,)
    gathered_copy_kv_kernel[grid](
        k_src_flat,
        v_src_flat,
        k_dst_flat,
        v_dst_flat,
        indices,
        num_dst_blocks,
        block_numel=block_numel,
    )
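

# Illustrative CUDA Graph sketch (not in the original module), showing the property
# claimed in the module docstring: tensor addresses are baked in at capture time,
# and only the *contents* of `indices` change between replays. All names below are
# hypothetical; a warm-up call is made before capture so the Triton kernel is
# already compiled.
def _example_graph_replay(k_src, v_src, k_dst, v_dst, indices) -> None:
    # Warm up on a side stream before capture (compiles the Triton kernel).
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        gathered_copy_kv(k_src, v_src, k_dst, v_dst, indices)
    torch.cuda.current_stream().wait_stream(side)

    # Capture once; the kernel launch records the fixed tensor addresses.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        gathered_copy_kv(k_src, v_src, k_dst, v_dst, indices)

    # Before each replay, rewrite the index contents in place, then replay.
    new_indices = torch.randint(0, k_src.shape[0], indices.shape,
                                dtype=indices.dtype, device=indices.device)
    indices.copy_(new_indices)
    graph.replay()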