[WIP] Fix attention compute error: use fp32 in the LSE/output merge kernels and tighten stream/event synchronization in the KV-cache offload path.
@@ -281,7 +281,11 @@ def _merge_lse_kernel(
     num_elements: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
-    """Fused kernel for merging LSE values."""
+    """Fused kernel for merging LSE values.
+
+    IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
+    bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
+    """
     # Each program handles BLOCK_SIZE elements
     pid = tl.program_id(0)
     block_start = pid * BLOCK_SIZE
@@ -289,21 +293,21 @@ def _merge_lse_kernel(
     offsets = block_start + tl.arange(0, BLOCK_SIZE)
     mask = offsets < num_elements

-    # Load lse values
-    lse1 = tl.load(lse1_ptr + offsets, mask=mask)
-    lse2 = tl.load(lse2_ptr + offsets, mask=mask)
+    # Load lse values and convert to fp32 for precision
+    lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
+    lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)

-    # Compute max for numerical stability
+    # Compute max for numerical stability (in fp32)
     max_lse = tl.maximum(lse1, lse2)

-    # Compute exp(lse - max_lse)
+    # Compute exp(lse - max_lse) in fp32
     exp1 = tl.exp(lse1 - max_lse)
     exp2 = tl.exp(lse2 - max_lse)

-    # Compute merged LSE: max_lse + log(exp1 + exp2)
+    # Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
     lse_merged = max_lse + tl.log(exp1 + exp2)

-    # Store result
+    # Store result (convert back to original dtype)
     tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)


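For reference, the merge this kernel implements is the standard stable log-sum-exp combination. A minimal PyTorch sketch, useful as an fp32 ground truth when testing the Triton kernel (the function name merge_lse_reference is illustrative, not part of this patch):

import torch

def merge_lse_reference(lse1: torch.Tensor, lse2: torch.Tensor) -> torch.Tensor:
    # Upcast to fp32 before exp/log, mirroring the kernel's precision fix
    lse1, lse2 = lse1.float(), lse2.float()
    max_lse = torch.maximum(lse1, lse2)
    # Numerically stable log(exp(lse1) + exp(lse2))
    return max_lse + torch.log(torch.exp(lse1 - max_lse) + torch.exp(lse2 - max_lse))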
@@ -313,7 +317,11 @@ def _merge_output_kernel(
     batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
-    """Fused kernel for merging attention outputs."""
+    """Fused kernel for merging attention outputs.
+
+    IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
+    This is critical for numerical accuracy in chunked attention.
+    """
     # Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
     pid_batch = tl.program_id(0)
     pid_seq = tl.program_id(1)
@@ -322,11 +330,11 @@ def _merge_output_kernel(
     # Compute LSE index: [batch, nheads, seqlen_q]
     lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq

-    # Load LSE values
-    lse1 = tl.load(lse1_ptr + lse_idx)
-    lse2 = tl.load(lse2_ptr + lse_idx)
+    # Load LSE values and convert to fp32 for precision
+    lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
+    lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)

-    # Compute max and scaling factors
+    # Compute max and scaling factors in fp32
     max_lse = tl.maximum(lse1, lse2)
     exp1 = tl.exp(lse1 - max_lse)
     exp2 = tl.exp(lse2 - max_lse)
@@ -343,14 +351,14 @@ def _merge_output_kernel(
                 pid_head * headdim)
     o_idx = base_idx + d_idx

-    # Load o1, o2
-    o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0)
-    o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0)
+    # Load o1, o2 and convert to fp32 for weighted sum
+    o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
+    o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)

-    # Compute merged output: (o1 * exp1 + o2 * exp2) / sum_exp
+    # Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
     o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp

-    # Store result
+    # Store result (Triton will convert back to original dtype)
     tl.store(o_out_ptr + o_idx, o_merged, mask=mask)


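The output merge applies the same idea: each partial attention output is weighted by exp(lse_i - max_lse) and normalized by the sum of the weights, all in fp32. A hedged reference sketch (the layouts o: [batch, seqlen_q, nheads, headdim] and lse: [batch, nheads, seqlen_q] follow the indexing in the kernel above; the function name is illustrative):

import torch

def merge_output_reference(o1, o2, lse1, lse2):
    # o1, o2: [batch, seqlen_q, nheads, headdim]; lse1, lse2: [batch, nheads, seqlen_q]
    lse1, lse2 = lse1.float(), lse2.float()
    max_lse = torch.maximum(lse1, lse2)
    exp1 = torch.exp(lse1 - max_lse)   # scaling factor for chunk 1
    exp2 = torch.exp(lse2 - max_lse)   # scaling factor for chunk 2
    sum_exp = exp1 + exp2
    # Reshape per-(batch, head, seq) weights to [batch, seqlen_q, nheads, 1] so they broadcast over headdim
    w1 = (exp1 / sum_exp).permute(0, 2, 1).unsqueeze(-1)
    w2 = (exp2 / sum_exp).permute(0, 2, 1).unsqueeze(-1)
    return (o1.float() * w1 + o2.float() * w2).to(o1.dtype)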
@@ -337,10 +337,10 @@ class HybridKVCacheManager(KVCacheManager):
             block = self.logical_blocks[logical_id]
             if block.location == BlockLocation.CPU:
                 cpu_blocks.append(block.cpu_block_id)
-        logger.debug(
-            f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
-            f"returned cpu_blocks={cpu_blocks}"
-        )
+        # logger.debug(
+        #     f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
+        #     f"returned cpu_blocks={cpu_blocks}"
+        # )
         return cpu_blocks

     # ========== Ring Buffer CPU-primary support ==========
@@ -538,7 +538,7 @@ class OffloadEngine:

     def sync_indices(self) -> None:
         """Synchronize to ensure all index updates are complete."""
-        torch.cuda.current_stream().synchronize()
+        torch.cuda.default_stream().synchronize()

     # ========== Cache access methods ==========

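sync_indices now synchronizes the default stream instead of whatever stream is current; presumably the index-update copies are issued on the default stream, so a caller running inside another stream context would otherwise not wait for them. A small illustrative sketch of the distinction (not from the patch):

import torch

s = torch.cuda.Stream()
with torch.cuda.stream(s):
    # Inside this context the "current" stream is s, so
    # torch.cuda.current_stream().synchronize() would wait on s,
    # not on work that was queued on the default stream.
    assert torch.cuda.current_stream() == s
torch.cuda.default_stream().synchronize()  # waits for everything queued on the default stream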
@@ -682,8 +682,9 @@ class OffloadEngine:
         Async load a single CPU block to a ring buffer slot for one layer.

         This is the core building block for ring buffer pipelining.
-        Before starting the transfer, waits for any previous compute on this slot
-        to complete (using compute_done event).
+        Before starting the transfer, waits for:
+        1. Any previous compute on this slot to complete
+        2. Any pending offload of this slot to complete

         Args:
             slot_idx: Target GPU slot index
@@ -701,6 +702,10 @@ class OffloadEngine:
            # This prevents data race: transfer must not start until attention finishes reading
            stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])

+           # Also wait for any pending offload of this slot to complete
+           # This prevents race: load must not write GPU slot while offload is reading from it
+           stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
+
            self.k_cache_gpu[layer_id, slot_idx].copy_(
                self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
            )
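The two wait_event calls form a double-buffering handshake: the compute side records compute_done after it finishes reading a slot, the offload side records offload_done after it finishes draining it, and the loader waits on both before overwriting the slot. A standalone sketch of the event pattern with generic stream/tensor names (only the event idea comes from the patch):

import torch

compute_stream = torch.cuda.Stream()
transfer_stream = torch.cuda.Stream()

slot = torch.empty(1024, device="cuda")         # stand-in for one ring-buffer slot
cpu_block = torch.randn(1024, pin_memory=True)  # stand-in for a pinned CPU block
compute_done = torch.cuda.Event()

with torch.cuda.stream(compute_stream):
    _ = slot.sum()                          # "attention" reads the slot
    compute_done.record(compute_stream)     # mark the read as issued on this stream

with torch.cuda.stream(transfer_stream):
    transfer_stream.wait_event(compute_done)    # do not overwrite the slot until the read completes
    slot.copy_(cpu_block, non_blocking=True)    # async H2D refill of the slot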
@@ -763,7 +768,11 @@ class OffloadEngine:

        torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
        with torch.cuda.stream(self.transfer_stream_main):
+           # Wait for both compute_stream and default stream
+           # - compute_stream: for flash attention operations
+           # - default_stream: for store_kvcache which runs on default stream
            self.transfer_stream_main.wait_stream(self.compute_stream)
+           self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
            memcpy_2d_async(
                self.k_cache_cpu[:, cpu_block_id],
                self.k_cache_gpu[:, slot_idx],
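The offload side orders the D2H copy with wait_stream rather than events: the transfer stream must observe everything already queued on the compute stream (flash attention) and on the default stream (store_kvcache) before it reads the GPU slot. A hedged sketch of that ordering with a pinned destination, which is what lets the copy overlap at all (buffer names are illustrative, not the engine's actual fields):

import torch

compute_stream = torch.cuda.Stream()
transfer_stream = torch.cuda.Stream()

k_gpu_slot = torch.randn(256, 128, device="cuda")     # stand-in for one GPU KV slot
k_cpu_block = torch.empty(256, 128, pin_memory=True)  # pinned CPU block, required for async D2H

with torch.cuda.stream(transfer_stream):
    # Do not start reading the slot until prior producer work is ordered before this copy
    transfer_stream.wait_stream(compute_stream)
    transfer_stream.wait_stream(torch.cuda.default_stream())
    k_cpu_block.copy_(k_gpu_slot, non_blocking=True)   # async D2H into pinned memory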
@@ -793,7 +802,9 @@ class OffloadEngine:
            cpu_block_id: Target CPU block ID
        """
        with torch.cuda.stream(self.transfer_stream_main):
+           # Wait for both compute_stream and default stream
            self.transfer_stream_main.wait_stream(self.compute_stream)
+           self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
            self.k_cache_cpu[layer_id, cpu_block_id].copy_(
                self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
            )