[WIP] fixing attention compute error.

This commit is contained in:
Zijie Tian
2025-12-30 00:31:48 +08:00
parent bf4c63c7ec
commit 89f8020d38
12 changed files with 2175 additions and 103 deletions

View File

@@ -31,6 +31,8 @@ class LLMEngine:
self.model_runner = ModelRunner(config, 0, self.events)
self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True)
config.eos = self.tokenizer.eos_token_id
# Set Sequence.block_size to match the KV cache block size
Sequence.block_size = config.kvcache_block_size
self.scheduler = Scheduler(config, self.model_runner.kvcache_manager)
atexit.register(self.exit)

View File

@@ -521,6 +521,7 @@ class ModelRunner:
print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)
# Sample from last logits
# For chunked prefill, ParallelLMHead automatically selects last position's logits
temperatures = self.prepare_sample(seqs) if self.rank == 0 else None
if logits is not None:
token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None

View File

@@ -281,7 +281,11 @@ def _merge_lse_kernel(
num_elements: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging LSE values."""
"""Fused kernel for merging LSE values.
IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
"""
# Each program handles BLOCK_SIZE elements
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
@@ -289,21 +293,21 @@ def _merge_lse_kernel(
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < num_elements
# Load lse values
lse1 = tl.load(lse1_ptr + offsets, mask=mask)
lse2 = tl.load(lse2_ptr + offsets, mask=mask)
# Load lse values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)
# Compute max for numerical stability
# Compute max for numerical stability (in fp32)
max_lse = tl.maximum(lse1, lse2)
# Compute exp(lse - max_lse)
# Compute exp(lse - max_lse) in fp32
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
# Compute merged LSE: max_lse + log(exp1 + exp2)
# Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
lse_merged = max_lse + tl.log(exp1 + exp2)
# Store result
# Store result (convert back to original dtype)
tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
@@ -313,7 +317,11 @@ def _merge_output_kernel(
batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""Fused kernel for merging attention outputs."""
"""Fused kernel for merging attention outputs.
IMPORTANT: Uses fp32 for exp operations and weighted sum to avoid precision loss.
This is critical for numerical accuracy in chunked attention.
"""
# Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
pid_batch = tl.program_id(0)
pid_seq = tl.program_id(1)
@@ -322,11 +330,11 @@ def _merge_output_kernel(
# Compute LSE index: [batch, nheads, seqlen_q]
lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq
# Load LSE values
lse1 = tl.load(lse1_ptr + lse_idx)
lse2 = tl.load(lse2_ptr + lse_idx)
# Load LSE values and convert to fp32 for precision
lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)
# Compute max and scaling factors
# Compute max and scaling factors in fp32
max_lse = tl.maximum(lse1, lse2)
exp1 = tl.exp(lse1 - max_lse)
exp2 = tl.exp(lse2 - max_lse)
@@ -343,14 +351,14 @@ def _merge_output_kernel(
pid_head * headdim)
o_idx = base_idx + d_idx
# Load o1, o2
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0)
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0)
# Load o1, o2 and convert to fp32 for weighted sum
o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
# Compute merged output: (o1 * exp1 + o2 * exp2) / sum_exp
# Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
# Store result
# Store result (Triton will convert back to original dtype)
tl.store(o_out_ptr + o_idx, o_merged, mask=mask)

View File

@@ -337,10 +337,10 @@ class HybridKVCacheManager(KVCacheManager):
block = self.logical_blocks[logical_id]
if block.location == BlockLocation.CPU:
cpu_blocks.append(block.cpu_block_id)
logger.debug(
f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
f"returned cpu_blocks={cpu_blocks}"
)
# logger.debug(
# f"get_prefilled_cpu_blocks: prefilled_blocks={list(self.prefilled_blocks)}, "
# f"returned cpu_blocks={cpu_blocks}"
# )
return cpu_blocks
# ========== Ring Buffer CPU-primary support ==========

View File

@@ -538,7 +538,7 @@ class OffloadEngine:
def sync_indices(self) -> None:
"""Synchronize to ensure all index updates are complete."""
torch.cuda.current_stream().synchronize()
torch.cuda.default_stream().synchronize()
# ========== Cache access methods ==========
@@ -682,8 +682,9 @@ class OffloadEngine:
Async load a single CPU block to a ring buffer slot for one layer.
This is the core building block for ring buffer pipelining.
Before starting the transfer, waits for any previous compute on this slot
to complete (using compute_done event).
Before starting the transfer, waits for:
1. Any previous compute on this slot to complete
2. Any pending offload of this slot to complete
Args:
slot_idx: Target GPU slot index
@@ -701,6 +702,10 @@ class OffloadEngine:
# This prevents data race: transfer must not start until attention finishes reading
stream.wait_event(self.ring_slot_compute_done[slot_idx][layer_id])
# Also wait for any pending offload of this slot to complete
# This prevents race: load must not write GPU slot while offload is reading from it
stream.wait_event(self.ring_slot_all_layers_offload_done[slot_idx])
self.k_cache_gpu[layer_id, slot_idx].copy_(
self.k_cache_cpu[layer_id, cpu_block_id], non_blocking=True
)
@@ -763,7 +768,11 @@ class OffloadEngine:
torch.cuda.nvtx.range_push(f"D2H: Slot[{slot_idx}]->CPU[{cpu_block_id}]")
with torch.cuda.stream(self.transfer_stream_main):
# Wait for both compute_stream and default stream
# - compute_stream: for flash attention operations
# - default_stream: for store_kvcache which runs on default stream
self.transfer_stream_main.wait_stream(self.compute_stream)
self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
memcpy_2d_async(
self.k_cache_cpu[:, cpu_block_id],
self.k_cache_gpu[:, slot_idx],
@@ -793,7 +802,9 @@ class OffloadEngine:
cpu_block_id: Target CPU block ID
"""
with torch.cuda.stream(self.transfer_stream_main):
# Wait for both compute_stream and default stream
self.transfer_stream_main.wait_stream(self.compute_stream)
self.transfer_stream_main.wait_stream(torch.cuda.default_stream())
self.k_cache_cpu[layer_id, cpu_block_id].copy_(
self.k_cache_gpu[layer_id, slot_idx], non_blocking=True
)

View File

@@ -169,9 +169,11 @@ class Attention(nn.Module):
else:
# Use ring buffer pipeline
o_acc, lse_acc = self._ring_buffer_pipeline_load(
q_batched, cpu_block_table, load_slots, offload_engine
q_batched, cpu_block_table, load_slots, offload_engine,
current_chunk_idx
)
# Compute attention against current chunk's KV (with causal mask)
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
current_o, current_lse = flash_attn_with_lse(
@@ -187,11 +189,18 @@ class Attention(nn.Module):
if o_acc is None:
final_o = current_o
else:
# IMPORTANT: o_acc was computed on compute_stream. We need to sync before
# reading it on the default stream for the merge operation.
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
offload_engine = kvcache_manager.offload_engine
torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop() # ChunkedPrefill
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
@@ -205,24 +214,27 @@ class Attention(nn.Module):
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
o_acc, lse_acc = None, None
compute_stream = offload_engine.compute_stream
for block_idx, cpu_block_id in enumerate(cpu_block_table):
# Load to slot 0 (single slot)
offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
offload_engine.wait_slot_layer(0, self.layer_id)
prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
# IMPORTANT: Must use compute_stream to match wait_slot_layer
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(0, self.layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
@@ -232,6 +244,7 @@ class Attention(nn.Module):
cpu_block_table: list,
load_slots: list,
offload_engine,
current_chunk_idx: int = -1,
):
"""
Ring buffer async pipeline loading with double buffering.
@@ -269,22 +282,26 @@ class Attention(nn.Module):
if pipeline_depth == 1:
# Only 1 slot available, cannot pipeline - use synchronous mode
# IMPORTANT: Must use compute_stream to match synchronization in
# load_to_slot_layer (waits for compute_done) and wait_slot_layer
slot = load_slots[0]
compute_stream = offload_engine.compute_stream
for block_idx in range(num_blocks):
offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_table[block_idx])
offload_engine.wait_slot_layer(slot, self.layer_id)
prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
# Record compute done so next load can safely reuse this slot
offload_engine.record_slot_compute_done(slot, self.layer_id)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(slot, self.layer_id)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
# Record compute done so next load can safely reuse this slot
offload_engine.record_slot_compute_done(slot, self.layer_id)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
# N-way pipeline: use ALL available slots for maximum overlap
@@ -378,12 +395,13 @@ class Attention(nn.Module):
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq
# Get all CPU blocks for this sequence
cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
# Get only PREFILLED CPU blocks (exclude the current decode block)
# The decode block's KV is still in GPU decode_slot, not yet offloaded to CPU
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
if self.layer_id == 0:
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no CPU blocks available")
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
# Apply sparse policy if enabled
if kvcache_manager.sparse_policy is not None:
@@ -401,12 +419,17 @@ class Attention(nn.Module):
)
offload_engine = kvcache_manager.offload_engine
compute_stream = offload_engine.compute_stream
# Chunk size = capacity of each double buffer region (compute/prefetch)
# Each region uses half of decode_load_slots
chunk_size = max(1, len(offload_engine.decode_load_slots) // 2)
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
# Check if double buffering is possible (need at least 2 separate regions)
# With only 1 load slot, compute and prefetch regions overlap -> can't double buffer
can_double_buffer = len(offload_engine.decode_load_slots) >= 2
o_acc = None
lse_acc = None
@@ -422,49 +445,53 @@ class Attention(nn.Module):
end = min(start + chunk_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
# Wait for current buffer to be ready
if use_compute:
offload_engine.wait_compute_layer(self.layer_id)
else:
offload_engine.wait_prefetch_layer(self.layer_id)
# Wait for current buffer to be ready on compute_stream
# The load runs on transfer_stream_main, compute runs on compute_stream
compute_stream.wait_stream(offload_engine.transfer_stream_main)
# Trigger async prefetch of next chunk to the OTHER buffer
# This overlaps transfer with current chunk's computation
# All computation on explicit compute_stream
with torch.cuda.stream(compute_stream):
# Get KV from current buffer FIRST, before prefetching overwrites it
if use_compute:
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
self.layer_id, num_blocks_in_chunk
)
else:
k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
self.layer_id, num_blocks_in_chunk
)
# Compute attention for this chunk
o_chunk, lse_chunk = flash_attn_with_lse(
q_batched, k_chunk, v_chunk,
softmax_scale=self.scale,
causal=False,
)
# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = o_chunk, lse_chunk
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Trigger async prefetch/load of next chunk to the OTHER buffer
# This happens AFTER attention completes, so the data is no longer needed
if chunk_idx + 1 < num_chunks:
next_start = end
next_end = min(next_start + chunk_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if use_compute:
# Current in Compute, prefetch next to Prefetch region
offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
if can_double_buffer:
if use_compute:
# Current in Compute, prefetch next to Prefetch region
offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
else:
# Current in Prefetch, prefetch next to Compute region
offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
else:
# Current in Prefetch, prefetch next to Compute region
# Sync fallback: load next chunk to same slot (always compute region)
offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
# Get KV from current buffer
if use_compute:
k_chunk, v_chunk = offload_engine.get_kv_for_compute(
self.layer_id, num_blocks_in_chunk
)
else:
k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(
self.layer_id, num_blocks_in_chunk
)
# Compute attention for this chunk
o_chunk, lse_chunk = flash_attn_with_lse(
q_batched, k_chunk, v_chunk,
softmax_scale=self.scale,
causal=False,
)
# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = o_chunk, lse_chunk
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Swap buffers for next iteration
# Swap buffers for next iteration (only matters if can_double_buffer)
use_compute = not use_compute
# Now attend to Decode region (contains accumulated decode tokens)
@@ -472,24 +499,29 @@ class Attention(nn.Module):
start_pos = context.decode_start_pos_in_block
num_accumulated = pos_in_block - start_pos + 1
if num_accumulated > 0:
decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_k = decode_k.unsqueeze(0)
decode_v = decode_v.unsqueeze(0)
with torch.cuda.stream(compute_stream):
if num_accumulated > 0:
decode_k = offload_engine.k_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_v = offload_engine.v_cache_gpu[self.layer_id, offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_k = decode_k.unsqueeze(0)
decode_v = decode_v.unsqueeze(0)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, decode_k, decode_v,
softmax_scale=self.scale,
causal=False,
)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, decode_k, decode_v,
softmax_scale=self.scale,
causal=False,
)
if o_acc is None:
o_acc = decode_o
else:
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
if o_acc is None:
o_acc = decode_o
else:
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
if o_acc is None:
raise RuntimeError("Chunked decode attention failed: no KV available")
# Sync back to default stream before returning
# Caller expects result to be ready on default stream
torch.cuda.default_stream().wait_stream(compute_stream)
return o_acc