[WIP] Replace merge attention with Triton kernel.
@@ -275,6 +275,85 @@ def flash_attn_with_lse(
     return out, lse


+@triton.jit
+def _merge_lse_kernel(
+    lse1_ptr, lse2_ptr, lse_out_ptr,
+    num_elements: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """Fused kernel for merging LSE values."""
+    # Each program handles BLOCK_SIZE elements
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < num_elements
+
+    # Load lse values
+    lse1 = tl.load(lse1_ptr + offsets, mask=mask)
+    lse2 = tl.load(lse2_ptr + offsets, mask=mask)
+
+    # Compute max for numerical stability
+    max_lse = tl.maximum(lse1, lse2)
+
+    # Compute exp(lse - max_lse)
+    exp1 = tl.exp(lse1 - max_lse)
+    exp2 = tl.exp(lse2 - max_lse)
+
+    # Compute merged LSE: max_lse + log(exp1 + exp2)
+    lse_merged = max_lse + tl.log(exp1 + exp2)
+
+    # Store result
+    tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
+
+
+@triton.jit
+def _merge_output_kernel(
+    o1_ptr, o2_ptr, lse1_ptr, lse2_ptr, o_out_ptr,
+    batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """Fused kernel for merging attention outputs."""
+    # Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
+    pid_batch = tl.program_id(0)
+    pid_seq = tl.program_id(1)
+    pid_head = tl.program_id(2)
+
+    # Compute LSE index: [batch, nheads, seqlen_q]
+    lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq
+
+    # Load LSE values
+    lse1 = tl.load(lse1_ptr + lse_idx)
+    lse2 = tl.load(lse2_ptr + lse_idx)
+
+    # Compute max and scaling factors
+    max_lse = tl.maximum(lse1, lse2)
+    exp1 = tl.exp(lse1 - max_lse)
+    exp2 = tl.exp(lse2 - max_lse)
+    sum_exp = exp1 + exp2
+
+    # Process headdim in chunks
+    for d_offset in range(0, headdim, BLOCK_SIZE):
+        d_idx = d_offset + tl.arange(0, BLOCK_SIZE)
+        mask = d_idx < headdim
+
+        # Compute output index: [batch, seqlen_q, nheads, headdim]
+        base_idx = (pid_batch * seqlen_q * nheads * headdim +
+                    pid_seq * nheads * headdim +
+                    pid_head * headdim)
+        o_idx = base_idx + d_idx
+
+        # Load o1, o2
+        o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0)
+        o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0)
+
+        # Compute merged output: (o1 * exp1 + o2 * exp2) / sum_exp
+        o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
+
+        # Store result
+        tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
+
+
 def merge_attention_outputs(
     o1: torch.Tensor,
     lse1: torch.Tensor,
@@ -282,7 +361,7 @@ def merge_attention_outputs(
     lse2: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
-    Merge two attention outputs using online softmax.
+    Merge two attention outputs using online softmax (Triton fused kernel).

     This implements the online softmax merging formula:
     - m_new = max(lse1, lse2)
@@ -299,31 +378,30 @@ def merge_attention_outputs(
         o_merged: Merged output [batch, seqlen_q, nheads, headdim]
         lse_merged: Merged LSE [batch, nheads, seqlen_q]
     """
     # lse shape: [batch, nheads, seqlen_q]
     # o shape: [batch, seqlen_q, nheads, headdim]
     batch, seqlen_q, nheads, headdim = o1.shape

-    # Compute max for numerical stability
-    max_lse = torch.maximum(lse1, lse2)
+    # Allocate output tensors
+    o_merged = torch.empty_like(o1)
+    lse_merged = torch.empty_like(lse1)

-    # Compute scaling factors
-    # exp1, exp2 shape: [batch, nheads, seqlen_q]
-    exp1 = torch.exp(lse1 - max_lse)
-    exp2 = torch.exp(lse2 - max_lse)
+    # Launch LSE merge kernel
+    num_lse_elements = batch * nheads * seqlen_q
+    BLOCK_SIZE_LSE = 256
+    grid_lse = (triton.cdiv(num_lse_elements, BLOCK_SIZE_LSE),)
+    _merge_lse_kernel[grid_lse](
+        lse1, lse2, lse_merged,
+        num_lse_elements,
+        BLOCK_SIZE=BLOCK_SIZE_LSE,
+    )

-    # Reshape for broadcasting with output
-    # [batch, nheads, seqlen_q] -> [batch, seqlen_q, nheads, 1]
-    exp1_broad = exp1.transpose(1, 2).unsqueeze(-1)
-    exp2_broad = exp2.transpose(1, 2).unsqueeze(-1)
-
-    # Merge outputs
-    sum_exp = exp1_broad + exp2_broad
-    o_merged = (o1 * exp1_broad + o2 * exp2_broad) / sum_exp
-
-    # Compute merged LSE
-    lse_merged = max_lse + torch.log(exp1 + exp2)
-
-    # Ensure output has same dtype as input
-    o_merged = o_merged.to(o1.dtype)
+    # Launch output merge kernel
+    BLOCK_SIZE = 128
+    grid_output = (batch, seqlen_q, nheads)
+    _merge_output_kernel[grid_output](
+        o1, o2, lse1, lse2, o_merged,
+        batch, seqlen_q, nheads, headdim,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )

     return o_merged, lse_merged

@@ -14,6 +14,7 @@ from typing import Dict, List, Tuple, Optional
 from dataclasses import dataclass

 from nanovllm.kvcache.kernels import gathered_copy_kv
+from nanovllm.comm import memcpy_2d_async
 from nanovllm.utils.logger import get_logger

 logger = get_logger("offload_engine")
@@ -65,6 +66,16 @@ class OffloadEngine:
         self.kv_dim = num_kv_heads * head_dim
         self.block_numel = block_size * self.kv_dim

+        # ========== sgDMA pitch parameters for strided transfers ==========
+        self.dtype_size = dtype.itemsize
+        self.cpu_pitch = num_cpu_blocks * self.block_numel * self.dtype_size
+        self.gpu_pitch = num_gpu_blocks * self.block_numel * self.dtype_size
+        self.width = self.block_numel * self.dtype_size
+        self.height = num_layers
+
+        logger.info(f"sgDMA parameters: cpu_pitch={self.cpu_pitch}, gpu_pitch={self.gpu_pitch}, "
+                    f"width={self.width}, height={self.height}")
+
         # ========== Unified Ring Buffer configuration ==========
         # Constraint checks
         assert num_gpu_blocks >= 2, \
@@ -478,14 +489,18 @@ class OffloadEngine:

         with torch.cuda.stream(stream):
             for cpu_block_id, gpu_slot in zip(cpu_block_ids, gpu_slot_ids):
-                # Copy all layers at once
-                self.k_cache_gpu[:, gpu_slot].copy_(
+                # Copy all layers at once using sgDMA
+                memcpy_2d_async(
+                    self.k_cache_gpu[:, gpu_slot],
                     self.k_cache_cpu[:, cpu_block_id],
-                    non_blocking=True
+                    self.gpu_pitch, self.cpu_pitch, self.width, self.height,
+                    "h2d", stream=stream
                 )
-                self.v_cache_gpu[:, gpu_slot].copy_(
+                memcpy_2d_async(
+                    self.v_cache_gpu[:, gpu_slot],
                     self.v_cache_cpu[:, cpu_block_id],
-                    non_blocking=True
+                    self.gpu_pitch, self.cpu_pitch, self.width, self.height,
+                    "h2d", stream=stream
                 )

         stream.synchronize()
@@ -697,11 +712,17 @@ class OffloadEngine:
         logger.debug(f"Ring load all layers: CPU[{cpu_block_id}] -> GPU slot[{slot_idx}]")

         with torch.cuda.stream(self.transfer_stream_main):
-            self.k_cache_gpu[:, slot_idx].copy_(
-                self.k_cache_cpu[:, cpu_block_id], non_blocking=True
+            memcpy_2d_async(
+                self.k_cache_gpu[:, slot_idx],
+                self.k_cache_cpu[:, cpu_block_id],
+                self.gpu_pitch, self.cpu_pitch, self.width, self.height,
+                "h2d", stream=self.transfer_stream_main
             )
-            self.v_cache_gpu[:, slot_idx].copy_(
-                self.v_cache_cpu[:, cpu_block_id], non_blocking=True
+            memcpy_2d_async(
+                self.v_cache_gpu[:, slot_idx],
+                self.v_cache_cpu[:, cpu_block_id],
+                self.gpu_pitch, self.cpu_pitch, self.width, self.height,
+                "h2d", stream=self.transfer_stream_main
             )
             self.ring_slot_all_layers_ready[slot_idx].record(self.transfer_stream_main)

@@ -724,11 +745,17 @@ class OffloadEngine:
         torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
         with torch.cuda.stream(self.transfer_stream_main):
             self.transfer_stream_main.wait_stream(self.compute_stream)
-            self.k_cache_cpu[:, cpu_block_id].copy_(
-                self.k_cache_gpu[:, slot_idx], non_blocking=True
+            memcpy_2d_async(
+                self.k_cache_cpu[:, cpu_block_id],
+                self.k_cache_gpu[:, slot_idx],
+                self.cpu_pitch, self.gpu_pitch, self.width, self.height,
+                "d2h", stream=self.transfer_stream_main
             )
-            self.v_cache_cpu[:, cpu_block_id].copy_(
-                self.v_cache_gpu[:, slot_idx], non_blocking=True
+            memcpy_2d_async(
+                self.v_cache_cpu[:, cpu_block_id],
+                self.v_cache_gpu[:, slot_idx],
+                self.cpu_pitch, self.gpu_pitch, self.width, self.height,
+                "d2h", stream=self.transfer_stream_main
             )
             self.ring_slot_all_layers_offload_done[slot_idx].record(self.transfer_stream_main)
             torch.cuda.nvtx.range_pop()
@@ -813,11 +840,17 @@ class OffloadEngine:

         with torch.cuda.stream(self.transfer_stream_main):
             self.transfer_stream_main.wait_stream(self.compute_stream)
-            self.k_cache_cpu[:, cpu_block_id].copy_(
-                self.k_cache_gpu[:, self.decode_slot], non_blocking=True
+            memcpy_2d_async(
+                self.k_cache_cpu[:, cpu_block_id],
+                self.k_cache_gpu[:, self.decode_slot],
+                self.cpu_pitch, self.gpu_pitch, self.width, self.height,
+                "d2h", stream=self.transfer_stream_main
             )
-            self.v_cache_cpu[:, cpu_block_id].copy_(
-                self.v_cache_gpu[:, self.decode_slot], non_blocking=True
+            memcpy_2d_async(
+                self.v_cache_cpu[:, cpu_block_id],
+                self.v_cache_gpu[:, self.decode_slot],
+                self.cpu_pitch, self.gpu_pitch, self.width, self.height,
+                "d2h", stream=self.transfer_stream_main
             )
             self.decode_offload_done.record(self.transfer_stream_main)
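
Note: the two Triton kernels above implement the standard online-softmax merge described in the docstring. A minimal eager PyTorch sketch of the same formula, adapted from the implementation this commit replaces, can be used to sanity-check the fused path (the helper name `merge_attention_outputs_ref` is hypothetical; shapes follow the docstring: `o` is `[batch, seqlen_q, nheads, headdim]`, `lse` is `[batch, nheads, seqlen_q]`):

```python
import torch

def merge_attention_outputs_ref(o1, lse1, o2, lse2):
    # Reference (eager PyTorch) online-softmax merge, matching the docstring:
    #   m_new = max(lse1, lse2)
    #   out   = (o1*exp(lse1-m_new) + o2*exp(lse2-m_new)) / (exp(lse1-m_new) + exp(lse2-m_new))
    #   lse   = m_new + log(exp(lse1-m_new) + exp(lse2-m_new))
    max_lse = torch.maximum(lse1, lse2)            # [batch, nheads, seqlen_q]
    exp1 = torch.exp(lse1 - max_lse)
    exp2 = torch.exp(lse2 - max_lse)
    # Broadcast scale factors to [batch, seqlen_q, nheads, 1] for the outputs
    e1 = exp1.transpose(1, 2).unsqueeze(-1)
    e2 = exp2.transpose(1, 2).unsqueeze(-1)
    o_merged = (o1.float() * e1 + o2.float() * e2) / (e1 + e2)
    lse_merged = max_lse + torch.log(exp1 + exp2)
    return o_merged.to(o1.dtype), lse_merged
```

Comparing this against the Triton kernels with `torch.testing.assert_close` on random inputs is a cheap check while the kernel path is still WIP.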