[claudesquad] update from 'perf_opt-2' on 07 Jan 26 05:58 CST
@@ -66,6 +66,12 @@ python bench_offload.py
 
 **Note**: The Python version in the path (python3.10) should match your environment.
 
+**CRITICAL**: After making code changes to `nanovllm/` source files, you MUST reinstall the package for changes to take effect:
+```bash
+pip install -e . --prefix=./.local --no-deps
+```
+Without reinstallation, Python will use the old cached version and your changes will NOT be reflected!
+
 ## Sparse Attention
 
 For sparse attention related content (block sparse attention, MInference, FlexPrefill, XAttention, AvgPool, etc.), refer to [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md).
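A quick way to confirm the reinstall actually took effect is to check which copy of `nanovllm` Python resolves. This is a generic check, not part of the repo, and assumes the package is importable from your environment:

```python
import importlib
import inspect

# Print the file Python actually imports for `nanovllm`. After the editable
# reinstall this should point at your working tree (or the ./.local prefix),
# not a stale copy in site-packages.
nanovllm = importlib.import_module("nanovllm")
print(inspect.getfile(nanovllm))
```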
@@ -455,8 +455,6 @@ class ModelRunner:
         3. After each chunk, offload from ring buffer slot to CPU
         4. All N-1 other slots are used to load previous chunks for attention
         """
-        import sys
-
         assert len(seqs) == 1, "Ring buffer prefill only supports single sequence"
         seq = seqs[0]

@@ -466,10 +464,9 @@ class ModelRunner:

         total_tokens = len(seq)
         num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk
-        print(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
-              f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
-              f"total_chunks={num_chunks}",
-              file=sys.stderr)
+        logger.debug(f"[Ring Buffer Prefill] Starting: {total_tokens} tokens, "
+                     f"ring_slots={offload_engine.num_ring_slots}, chunk={tokens_per_chunk} tokens, "
+                     f"total_chunks={num_chunks}")

         chunk_idx = 0
         logits = None

@@ -488,9 +485,8 @@ class ModelRunner:
             # CPU block index for this chunk
             block_idx = chunk_idx

-            print(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
-                  f"write_slot={write_slot}",
-                  file=sys.stderr)
+            logger.debug(f"[Ring Buffer Prefill] Chunk {chunk_idx}: tokens {chunk_start}-{chunk_end}, "
+                         f"write_slot={write_slot}")

             # Prepare inputs
             input_ids, positions = self._prepare_chunked_offload_chunk(

@@ -509,27 +505,17 @@ class ModelRunner:
             logical_id = seq.block_table[block_idx]
             self.kvcache_manager.prefilled_blocks.add(logical_id)

-            # NOTE: Per-layer offloading is now done in attention.forward
-            # Each layer offloads its KV to CPU immediately after computing attention.
-            # We just need to wait for the last offload to complete before reusing the slot.
-            if block_idx < len(cpu_block_ids):
-                # TODO: Sparse policy hook needs update for new GPU cache architecture
-                # The GPU cache no longer has layer dimension, so we can't access
-                # k_cache_gpu[layer_id, write_slot]. Sparse policy should be called
-                # in attention.forward after per-layer offload.
-                pass
-
-            # Wait for offload to complete before next chunk
-            # (slot will be reused after N chunks)
-            offload_engine.wait_slot_offload(write_slot)
+            # NOTE: Per-layer async offloading is now done in attention.forward
+            # Each layer offloads from its own prefill buffer - no waiting required!
+            # The sparse policy hook is called in offload_prefill_buffer_async.

             processed_tokens = chunk_end
             chunk_idx += 1

-        # Wait for all offloads to complete
-        offload_engine.wait_all_offload_done()
+        # Wait for all async prefill offloads to complete
+        offload_engine.wait_all_prefill_offloads()

-        print(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks", file=sys.stderr)
+        logger.debug(f"[Ring Buffer Prefill] Complete: {chunk_idx} chunks")

         # Sample from last logits
         # For chunked prefill, ParallelLMHead automatically selects last position's logits
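For orientation, a minimal sketch of the chunk loop that the hunks above modify, reconstructed from the docstring and log messages. Names are illustrative; the real ModelRunner method also prepares inputs, runs the model, and samples:

```python
def ring_buffer_prefill_sketch(total_tokens: int, tokens_per_chunk: int, num_ring_slots: int) -> None:
    # Ceiling division, e.g. 10_000 tokens with 4_096-token chunks -> 3 chunks.
    num_chunks = (total_tokens + tokens_per_chunk - 1) // tokens_per_chunk
    for chunk_idx in range(num_chunks):
        chunk_start = chunk_idx * tokens_per_chunk
        chunk_end = min(chunk_start + tokens_per_chunk, total_tokens)
        write_slot = chunk_idx % num_ring_slots  # current chunk's GPU slot cycles through the ring
        # 1) previous chunks are loaded from CPU into the other N-1 slots
        # 2) attention runs for tokens [chunk_start, chunk_end): causal against the
        #    current chunk's KV, non-causal against the loaded previous chunks
        # 3) the chunk's KV is offloaded to CPU block `chunk_idx` (after this commit,
        #    per layer and asynchronously from attention.forward)
    # After the loop the runner waits for all outstanding offloads, then samples.
```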
@@ -142,6 +142,30 @@ class OffloadEngine:
         decode_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
         logger.info(f" Per-layer decode buffer: {decode_buf_mb:.1f} MB")

+        # ========== Per-layer prefill buffer for async offload ==========
+        # During chunked prefill, all layers share the same GPU slot. This means
+        # each layer must wait for offload to complete before the next layer can
+        # write to the same slot. This serializes offloads and hurts performance.
+        # Solution: Maintain separate per-layer buffers for prefill.
+        # Each layer writes to its own buffer, enabling fully async offloads.
+        # Shape: [num_layers, block_size, kv_heads, head_dim]
+        self.prefill_k_buffer = torch.zeros(
+            num_layers, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        self.prefill_v_buffer = torch.zeros(
+            num_layers, block_size, num_kv_heads, head_dim,
+            dtype=dtype, device="cuda"
+        )
+        prefill_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * dtype.itemsize / (1024 * 1024)
+        logger.info(f" Per-layer prefill buffer: {prefill_buf_mb:.1f} MB")
+
+        # Per-layer offload events for async prefill offload
+        # Each layer has its own event to track offload completion
+        self.prefill_offload_events = [torch.cuda.Event() for _ in range(num_layers)]
+        # Per-layer transfer streams for parallel offloads
+        self.prefill_offload_streams = [torch.cuda.Stream() for _ in range(num_layers)]
+
         # ========== Fixed-address CPU KV cache (pinned memory) ==========
         self.k_cache_cpu = torch.zeros(
             num_layers, num_cpu_blocks, block_size, num_kv_heads, head_dim,
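To make the new log line concrete, here is the buffer-size arithmetic with assumed dimensions (32 layers, 1024-token blocks, 8 KV heads, head_dim 128, fp16). These values are illustrative, not the project's defaults:

```python
# prefill_buf_mb = 2 (K and V) * num_layers * block_size * num_kv_heads * head_dim * itemsize
num_layers, block_size, num_kv_heads, head_dim, itemsize = 32, 1024, 8, 128, 2  # assumed values
prefill_buf_mb = 2 * num_layers * block_size * num_kv_heads * head_dim * itemsize / (1024 * 1024)
print(f"Per-layer prefill buffer: {prefill_buf_mb:.1f} MB")  # -> 128.0 MB
```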
@@ -1063,4 +1087,92 @@ class OffloadEngine:
             # Allow pdb quit to propagate
             if e.__class__.__name__ == 'BdbQuit':
                 raise
             logger.warning(f"Debug hook error: {e}")
+
+    # ========== Per-layer Prefill Buffer Methods ==========
+    # These methods enable async offload during chunked prefill by using
+    # per-layer buffers instead of shared GPU slots.
+
+    def get_prefill_buffer(self, layer_id: int) -> Tuple[Tensor, Tensor]:
+        """
+        Get prefill buffer for a layer.
+
+        Args:
+            layer_id: Layer index
+
+        Returns:
+            (k_buffer, v_buffer), shape: [block_size, kv_heads, head_dim]
+        """
+        return self.prefill_k_buffer[layer_id], self.prefill_v_buffer[layer_id]
+
+    def get_prefill_buffer_slice(
+        self,
+        layer_id: int,
+        num_tokens: int,
+    ) -> Tuple[Tensor, Tensor]:
+        """
+        Get a slice of prefill buffer for attention computation.
+
+        Args:
+            layer_id: Layer index
+            num_tokens: Number of valid tokens in current chunk
+
+        Returns:
+            (k, v) with shape [1, num_tokens, kv_heads, head_dim]
+        """
+        k = self.prefill_k_buffer[layer_id, :num_tokens].unsqueeze(0)
+        v = self.prefill_v_buffer[layer_id, :num_tokens].unsqueeze(0)
+        return k, v
+
+    def offload_prefill_buffer_async(
+        self,
+        layer_id: int,
+        cpu_block_id: int,
+        num_valid_tokens: int = -1,
+    ) -> None:
+        """
+        Async offload prefill buffer to CPU (no waiting required).
+
+        This uses per-layer streams and events to enable fully async offloads.
+        Each layer can offload independently without blocking other layers.
+
+        Args:
+            layer_id: Layer index
+            cpu_block_id: Target CPU block ID
+            num_valid_tokens: Number of valid tokens (-1 = use block_size)
+        """
+        valid_tokens = num_valid_tokens if num_valid_tokens > 0 else self.block_size
+
+        # Collect sparse policy metadata before offload
+        if self.sparse_policy is not None:
+            k_cache = self.prefill_k_buffer[layer_id]
+            self.sparse_policy.on_prefill_offload(cpu_block_id, layer_id, k_cache, valid_tokens)
+
+        # Use per-layer stream for parallel offloads
+        stream = self.prefill_offload_streams[layer_id]
+
+        torch.cuda.nvtx.range_push(f"AsyncPrefillOffload: L{layer_id}->CPU[{cpu_block_id}]")
+        with torch.cuda.stream(stream):
+            # Wait for compute to finish writing to prefill buffer
+            stream.wait_stream(self.compute_stream)
+
+            # Copy from prefill buffer to CPU
+            self.k_cache_cpu[layer_id, cpu_block_id].copy_(
+                self.prefill_k_buffer[layer_id], non_blocking=True
+            )
+            self.v_cache_cpu[layer_id, cpu_block_id].copy_(
+                self.prefill_v_buffer[layer_id], non_blocking=True
+            )
+
+            # Record completion event
+            self.prefill_offload_events[layer_id].record(stream)
+        torch.cuda.nvtx.range_pop()
+
+    def wait_all_prefill_offloads(self) -> None:
+        """Wait for all prefill buffer offloads to complete."""
+        for stream in self.prefill_offload_streams:
+            stream.synchronize()
+
+    def wait_prefill_offload(self, layer_id: int) -> None:
+        """Wait for a specific layer's prefill offload to complete."""
+        self.prefill_offload_events[layer_id].synchronize()
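The methods above follow a standard CUDA stream/event pattern: each layer gets its own transfer stream that waits on the compute stream, copies into pinned CPU memory with `non_blocking=True`, and records an event. Below is a self-contained sketch of that pattern; sizes and names are illustrative, and it assumes a CUDA device is available:

```python
import torch

num_layers, block_size, kv_heads, head_dim, num_cpu_blocks = 4, 256, 8, 128, 16

prefill_k = torch.zeros(num_layers, block_size, kv_heads, head_dim,
                        dtype=torch.float16, device="cuda")
k_cpu = torch.zeros(num_layers, num_cpu_blocks, block_size, kv_heads, head_dim,
                    dtype=torch.float16, pin_memory=True)

compute_stream = torch.cuda.Stream()
offload_streams = [torch.cuda.Stream() for _ in range(num_layers)]
offload_events = [torch.cuda.Event() for _ in range(num_layers)]

def offload_layer_async(layer_id: int, cpu_block_id: int) -> None:
    """Queue one layer's copy to pinned CPU memory without blocking other layers."""
    stream = offload_streams[layer_id]
    stream.wait_stream(compute_stream)  # don't read the buffer before compute has written it
    with torch.cuda.stream(stream):
        k_cpu[layer_id, cpu_block_id].copy_(prefill_k[layer_id], non_blocking=True)
        offload_events[layer_id].record(stream)  # lets callers wait on just this layer

for layer_id in range(num_layers):
    offload_layer_async(layer_id, cpu_block_id=0)

# Before the buffers are reused (e.g. by the next chunk), wait for everything queued so far.
for s in offload_streams:
    s.synchronize()
```

Pinned (page-locked) CPU memory is what allows the `non_blocking=True` copies to overlap with compute instead of forcing a synchronous staging copy.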
@@ -99,8 +99,23 @@ class Attention(nn.Module):
         # torch.cuda.synchronize()
         #! =======================================================

-        if is_chunked_offload:
-            # Chunked offload mode: use compute_stream for store_kvcache
+        if is_chunked_offload and context.is_prefill:
+            # Chunked prefill mode: write KV to per-layer prefill buffer (not GPU slot)
+            # This enables fully async offloads since each layer has its own buffer.
+            offload_engine = context.kvcache_manager.offload_engine
+            compute_stream = offload_engine.compute_stream
+
+            # Wait for default stream to ensure slot_mapping tensor transfer is complete
+            compute_stream.wait_stream(torch.cuda.default_stream())
+
+            with torch.cuda.stream(compute_stream):
+                # Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
+                # k, v shape: [num_tokens, kv_heads, head_dim]
+                num_tokens = k.shape[0]
+                offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
+                offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
+        elif is_chunked_offload:
+            # Chunked decode mode: use compute_stream for store_kvcache
             # This ensures proper synchronization with per-layer offload
             compute_stream = context.kvcache_manager.offload_engine.compute_stream
             if k_cache.numel() and v_cache.numel():
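The `compute_stream.wait_stream(torch.cuda.default_stream())` call above establishes cross-stream ordering: the copy into the prefill buffer must not start before the default-stream work that produced `k`/`v` (and the slot_mapping transfer) has finished. A tiny illustration of that ordering rule, with illustrative tensors and assuming CUDA:

```python
import torch

side = torch.cuda.Stream()
k = torch.randn(1024, 8, 128, device="cuda")  # produced by work on the default stream
buf = torch.empty_like(k)

side.wait_stream(torch.cuda.default_stream())  # default-stream work ordered before side-stream work
with torch.cuda.stream(side):
    buf.copy_(k, non_blocking=True)

# The default stream must, in turn, wait for `side` before anything reads `buf`.
torch.cuda.default_stream().wait_stream(side)
```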
@@ -157,36 +172,36 @@ class Attention(nn.Module):
         context,
     ) -> torch.Tensor:
         """
-        Compute attention with unified ring buffer for chunked prefill.
+        Compute attention with per-layer prefill buffer for async offload.

-        Ring buffer design:
-        - Current chunk's KV is written to ring_slot[chunk_idx % N]
-        - Previous chunks' KV are loaded from CPU using N-1 available slots
-        - Pipeline: pre-fill slots, then process with overlapped load/compute
+        Optimized design:
+        - Current chunk's KV is written to per-layer prefill buffer (not GPU slot)
+        - Previous chunks' KV are loaded from CPU using GPU slots
+        - Each layer offloads from its own buffer - no waiting required!

         For each layer:
-        1. Current chunk's KV is in k_batched, v_batched (just written by model)
+        1. Current chunk's KV is in prefill_buffer[layer_id] (just written by model)
         2. Load previous chunks from CPU using available slots (pipeline)
         3. Compute attention against previous KV (no causal mask)
-        4. Compute attention against current KV (causal)
+        4. Compute attention against current KV from prefill buffer (causal)
         5. Merge all results using online softmax
+        6. Async offload prefill buffer to CPU (no waiting!)
         """
         from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

         current_chunk_idx = context.current_chunk_idx
         torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")

-        # q, k, v shape: [total_tokens, num_heads, head_dim]
-        # Reshape for flash attention: [batch, seq, heads, dim]
+        # q shape: [total_tokens, num_heads, head_dim]
         q_batched = q.unsqueeze(0)  # [1, total_tokens, heads, dim]
-        k_batched = k.unsqueeze(0)
-        v_batched = v.unsqueeze(0)
+        num_tokens = k.shape[0]

         o_acc = None
         lse_acc = None

         kvcache_manager = context.kvcache_manager
         seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
+        offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None

         if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
             # Get prefilled CPU blocks (blocks from previous chunks)
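Step 5 above ("merge all results using online softmax") combines partial attention outputs computed over disjoint key sets using their log-sum-exp (LSE) statistics. The project's `merge_attention_outputs` is not shown in this diff; the sketch below is a generic two-way merge under the usual flash-attention shape convention (`o`: `[batch, tokens, heads, dim]`, `lse`: `[batch, heads, tokens]`), which may differ from the actual implementation:

```python
import torch

def merge_two(o1: torch.Tensor, lse1: torch.Tensor,
              o2: torch.Tensor, lse2: torch.Tensor):
    """Merge two attention partials computed over disjoint key sets."""
    lse = torch.logaddexp(lse1, lse2)                         # combined normalizer, [B, H, T]
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # [B, T, H, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return o1 * w1 + o2 * w2, lse                             # same shapes as the inputs
```

Because each partial output is already normalized by its own softmax denominator, reweighting by `exp(lse_i - lse)` is equivalent to having run softmax over the union of keys; accumulating chunk by chunk this way is the online softmax the docstring refers to.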
@@ -210,11 +225,8 @@ class Attention(nn.Module):
             )

             if cpu_block_table:
-                offload_engine = kvcache_manager.offload_engine
-                # Get write slot for current chunk and available load slots
-                write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
-                load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
+                # Get available load slots (all slots can be used since we use prefill buffer)
+                load_slots = list(range(offload_engine.num_ring_slots))
                 pipeline_depth = len(load_slots)

                 if pipeline_depth == 0:
@@ -230,15 +242,14 @@ class Attention(nn.Module):
                 )

         # Get compute stream for all attention operations
-        compute_stream = None
-        if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
-            compute_stream = kvcache_manager.offload_engine.compute_stream
+        compute_stream = offload_engine.compute_stream if offload_engine is not None else None

-        # Compute attention against current chunk's KV (with causal mask)
-        # Use compute_stream to ensure proper sync with store_kvcache and offload
+        # Compute attention against current chunk's KV from prefill buffer (with causal mask)
         if compute_stream is not None:
             with torch.cuda.stream(compute_stream):
                 torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
+                # Get KV from per-layer prefill buffer
+                k_batched, v_batched = offload_engine.get_prefill_buffer_slice(self.layer_id, num_tokens)
                 current_o, current_lse = flash_attn_with_lse(
                     q_batched,
                     k_batched,
@@ -249,6 +260,8 @@ class Attention(nn.Module):
                 torch.cuda.nvtx.range_pop()
         else:
             torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
+            k_batched = k.unsqueeze(0)
+            v_batched = v.unsqueeze(0)
             current_o, current_lse = flash_attn_with_lse(
                 q_batched,
                 k_batched,
@@ -274,26 +287,16 @@ class Attention(nn.Module):

         torch.cuda.nvtx.range_pop()  # ChunkedPrefill

-        # Per-layer offload: In new GPU cache architecture (no layer dimension),
-        # each layer must offload its KV to CPU before next layer overwrites the GPU slot.
-        if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
-            offload_engine = kvcache_manager.offload_engine
-            write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
-            seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
-            if seq is not None:
-                cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
-                if current_chunk_idx < len(cpu_block_ids):
-                    cpu_block_id = cpu_block_ids[current_chunk_idx]
-                    # k.shape[0] = number of tokens in current chunk
-                    num_valid_tokens = k.shape[0]
-                    offload_engine.offload_slot_layer_to_cpu(
-                        write_slot, self.layer_id, cpu_block_id, num_valid_tokens
-                    )
-
-                    # CRITICAL: compute_stream must wait for offload to complete
-                    # before the next layer's store_kvcache can overwrite the GPU slot.
-                    # Without this, Layer N+1's store races with Layer N's offload copy.
-                    compute_stream.wait_event(offload_engine.ring_slot_offload_done[write_slot])
+        # Per-layer ASYNC offload: offload prefill buffer to CPU
+        # No waiting required! Each layer has its own buffer and stream.
+        if offload_engine is not None and seq is not None:
+            cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
+            if current_chunk_idx < len(cpu_block_ids):
+                cpu_block_id = cpu_block_ids[current_chunk_idx]
+                # Async offload - no waiting, fully parallel across layers
+                offload_engine.offload_prefill_buffer_async(
+                    self.layer_id, cpu_block_id, num_tokens
+                )

         # Sync default stream with compute_stream before returning
         # This ensures the result is ready for the rest of the model (layernorm, MLP)