[WIP] fixing attention compute error.

Zijie Tian
2025-12-30 00:31:48 +08:00
parent bf4c63c7ec
commit 89f8020d38
12 changed files with 2175 additions and 103 deletions

View File

@@ -281,7 +281,11 @@ def _merge_lse_kernel(
     num_elements: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
-    """Fused kernel for merging LSE values."""
+    """Fused kernel for merging LSE values.
+
+    IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
+    bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
+    """
     # Each program handles BLOCK_SIZE elements
     pid = tl.program_id(0)
     block_start = pid * BLOCK_SIZE
@@ -289,21 +293,21 @@ def _merge_lse_kernel(
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     mask = offsets < num_elements

-    # Load lse values
-    lse1 = tl.load(lse1_ptr + offsets, mask=mask)
-    lse2 = tl.load(lse2_ptr + offsets, mask=mask)
+    # Load lse values and convert to fp32 for precision
+    lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
+    lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)

-    # Compute max for numerical stability
+    # Compute max for numerical stability (in fp32)
     max_lse = tl.maximum(lse1, lse2)

-    # Compute exp(lse - max_lse)
+    # Compute exp(lse - max_lse) in fp32
     exp1 = tl.exp(lse1 - max_lse)
     exp2 = tl.exp(lse2 - max_lse)

-    # Compute merged LSE: max_lse + log(exp1 + exp2)
+    # Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
     lse_merged = max_lse + tl.log(exp1 + exp2)

-    # Store result
+    # Store result (convert back to original dtype)
     tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
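As a sanity check on the math above, a minimal eager-mode PyTorch sketch of the same two-way log-sum-exp merge with the fp32 upcast; the function name and test tensors are illustrative, not part of this repo:

import torch

def merge_lse_reference(lse1: torch.Tensor, lse2: torch.Tensor) -> torch.Tensor:
    # Upcast first: bf16 keeps only 7 mantissa bits, too coarse for exp/log.
    a, b = lse1.float(), lse2.float()
    m = torch.maximum(a, b)
    merged = m + torch.log(torch.exp(a - m) + torch.exp(b - m))
    return merged.to(lse1.dtype)  # cast back to the storage dtype, as the kernel does

lse1 = torch.randn(4, 8, 128, dtype=torch.bfloat16)
lse2 = torch.randn(4, 8, 128, dtype=torch.bfloat16)
ref = torch.logaddexp(lse1.float(), lse2.float())
# Loose tolerance: the final bf16 round-trip costs ~2-3 decimal digits.
assert torch.allclose(merge_lse_reference(lse1, lse2).float(), ref, atol=5e-2)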
@@ -313,7 +317,11 @@ def _merge_output_kernel(
     batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
-    """Fused kernel for merging attention outputs."""
+    """Fused kernel for merging attention outputs.
+
+    IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
+    This is critical for numerical accuracy in chunked attention.
+    """
     # Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
     pid_batch = tl.program_id(0)
     pid_seq = tl.program_id(1)
@@ -322,11 +330,11 @@ def _merge_output_kernel(
     # Compute LSE index: [batch, nheads, seqlen_q]
     lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq

-    # Load LSE values
-    lse1 = tl.load(lse1_ptr + lse_idx)
-    lse2 = tl.load(lse2_ptr + lse_idx)
+    # Load LSE values and convert to fp32 for precision
+    lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
+    lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)

-    # Compute max and scaling factors
+    # Compute max and scaling factors in fp32
     max_lse = tl.maximum(lse1, lse2)
     exp1 = tl.exp(lse1 - max_lse)
     exp2 = tl.exp(lse2 - max_lse)
@@ -343,14 +351,14 @@ def _merge_output_kernel(
                 pid_head * headdim)
     o_idx = base_idx + d_idx

-    # Load o1, o2
-    o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0)
-    o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0)
+    # Load o1, o2 and convert to fp32 for weighted sum
+    o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
+    o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)

-    # Compute merged output: (o1 * exp1 + o2 * exp2) / sum_exp
+    # Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
     o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp

-    # Store result
+    # Store result (Triton will convert back to original dtype)
     tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
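The output merge computes o = (o1 * exp1 + o2 * exp2) / sum_exp with sum_exp = exp1 + exp2. A hedged PyTorch reference of the same rescaling, assuming output layout [batch, seqlen_q, nheads, headdim] and LSE layout [batch, nheads, seqlen_q] as in the index computation above (all names illustrative):

import torch

def merge_output_reference(o1, o2, lse1, lse2):
    a, b = lse1.float(), lse2.float()
    m = torch.maximum(a, b)
    e1, e2 = torch.exp(a - m), torch.exp(b - m)
    # Broadcast [batch, nheads, seqlen_q] weights over headdim:
    w1 = e1.permute(0, 2, 1).unsqueeze(-1)  # -> [batch, seqlen_q, nheads, 1]
    w2 = e2.permute(0, 2, 1).unsqueeze(-1)
    # Weighted sum in fp32, then cast back to the storage dtype.
    merged = (o1.float() * w1 + o2.float() * w2) / (w1 + w2)
    return merged.to(o1.dtype)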

View File

@@ -337,10 +337,10 @@ class HybridKVCacheManager(KVCacheManager):
             block = self.logical_blocks[logical_id]
             if block.location == BlockLocation.CPU:
                 cpu_blocks.append(block.cpu_block_id)
-        logger.debug(
-            f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
-            f"returned cpu_blocks={cpu_blocks}"
-        )
+        # logger.debug(
+        #     f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
+        #     f"returned cpu_blocks={cpu_blocks}"
+        # )
         return cpu_blocks

     # ========== Ring Buffer CPU-primary support ==========
# ========== Ring Buffer CPU-primary support ==========

View File

@@ -538,7 +538,7 @@ class OffloadEngine:
     def sync_indices(self) -> None:
         """Synchronize to ensure all index updates are complete."""
-        torch.cuda.current_stream().synchronize()
+        torch.cuda.default_stream().synchronize()

     # ========== Cache access methods ==========
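For context, current_stream() and default_stream() only differ when the caller is inside a non-default stream context; presumably the index updates are enqueued on the default stream, which is why that is the stream drained here. An illustrative sketch of the hazard (the stream is hypothetical):

import torch

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    # Inside this block, current_stream() is s, so current_stream().synchronize()
    # would NOT drain work queued earlier on the default stream.
    torch.cuda.default_stream().synchronize()  # drains the index updates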
@@ -682,8 +682,9 @@ class OffloadEngine:
         Async load a single CPU block to a ring buffer slot for one layer.

         This is the core building block for ring buffer pipelining.
-        Before starting the transfer, waits for any previous compute on this slot
-        to complete (using compute_done event).
+        Before starting the transfer, waits for:
+        1. Any previous compute on this slot to complete
+        2. Any pending offload of this slot to complete

         Args:
             slot_idx: Target GPU slot index
@@ -701,6 +702,10 @@ class OffloadEngine:
             # This prevents data race: transfer must not start until attention finishes reading
             stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])

+            # Also wait for any pending offload of this slot to complete
+            # This prevents race: load must not write GPU slot while offload is reading from it
+            stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])

             self.k_cache_gpu[layer_id, slot_idx].copy_(
                 self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
             )
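The pattern being fixed here is event-gated slot reuse: the H2D load must order after both the last attention read of the slot and any in-flight offload from it. A minimal standalone sketch of that pattern (all names are illustrative, not the engine's attributes):

import torch

transfer = torch.cuda.Stream()
compute_done = torch.cuda.Event()   # recorded after attention finishes reading the slot
offload_done = torch.cuda.Event()   # recorded after the slot's D2H offload is enqueued

cpu_block = torch.empty(256, 4096, pin_memory=True)  # pinned host block
gpu_slot = torch.empty(256, 4096, device="cuda")     # ring-buffer slot

with torch.cuda.stream(transfer):
    transfer.wait_event(compute_done)  # slot no longer being read by attention
    transfer.wait_event(offload_done)  # slot no longer being read by the offload
    gpu_slot.copy_(cpu_block, non_blocking=True)  # H2D into the now-free slot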
@@ -763,7 +768,11 @@ class OffloadEngine:
         torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
         with torch.cuda.stream(self.transfer_stream_main):
+            # Wait for both compute_stream and default stream
+            # - compute_stream: for flash attention operations
+            # - default_stream: for store_kvcache which runs on default stream
             self.transfer_stream_main.wait_stream(self.compute_stream)
+            self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
             memcpy_2d_async(
                 self.k_cache_cpu[:, cpu_block_id],
                 self.k_cache_gpu[:, slot_idx],
@@ -793,7 +802,9 @@ class OffloadEngine:
             cpu_block_id: Target CPU block ID
         """
         with torch.cuda.stream(self.transfer_stream_main):
+            # Wait for both compute_stream and default stream
             self.transfer_stream_main.wait_stream(self.compute_stream)
+            self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
             self.k_cache_cpu[layer_id, cpu_block_id].copy_(
                 self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
             )
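Both offload paths now follow the same discipline: order the transfer stream after every producer of the slot (compute stream and default stream) before starting the D2H copy. A minimal standalone sketch, with illustrative shapes and names:

import torch

compute = torch.cuda.Stream()
transfer = torch.cuda.Stream()
k_gpu = torch.empty(256, 4096, device="cuda")    # GPU slot being evicted
k_cpu = torch.empty(256, 4096, pin_memory=True)  # pinned CPU destination

with torch.cuda.stream(transfer):
    transfer.wait_stream(compute)                       # attention reads/writes done
    transfer.wait_stream(torch.cuda.default_stream())   # KV-store writes done
    k_cpu.copy_(k_gpu, non_blocking=True)               # async D2H into pinned memory
offload_done = torch.cuda.Event()
offload_done.record(transfer)  # gate later reuse of the slot on this event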