Merge branch 'zijie/debug_chunk-2' into tzj/minference

This commit is contained in:
Zijie Tian
2026-01-07 03:30:38 +08:00
3 changed files with 42 additions and 10 deletions

View File

@@ -104,6 +104,10 @@ class Attention(nn.Module):
# This ensures proper synchronization with per-layer offload
compute_stream = context.kvcache_manager.offload_engine.compute_stream
if k_cache.numel() and v_cache.numel():
# CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
# slot_mapping is created with non_blocking=True on default stream, but we use it
# on compute_stream. Without this sync, index_copy_ can get corrupted indices.
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
else: