[fix] Fixed bench_offload.py, but performance degraded.

This commit is contained in:
Zijie Tian
2026-01-06 18:46:48 +08:00
parent 535f2037ab
commit 7cc8a394a5

View File

@@ -30,18 +30,23 @@ def store_kvcache(
v_cache: same shape as k_cache
slot_mapping: [N] with values as flat indices, -1 means skip
"""
# Filter out invalid slots (slot == -1)
valid_mask = slot_mapping >= 0
is_capturing = torch.cuda.is_current_stream_capturing()
if not is_capturing:
if is_capturing:
# During CUDA graph capture, assume all slots are valid.
# CUDA graphs don't support data-dependent operations like boolean indexing.
# This is safe because decode (captured) always has valid slots.
valid_slots = slot_mapping
valid_keys = key
valid_values = value
else:
# Normal execution: filter out invalid slots (slot == -1)
valid_mask = slot_mapping >= 0
if not valid_mask.any():
return
valid_slots = slot_mapping[valid_mask]
valid_keys = key[valid_mask] # [M, num_kv_heads, head_dim]
valid_values = value[valid_mask]
valid_slots = slot_mapping[valid_mask]
valid_keys = key[valid_mask] # [M, num_kv_heads, head_dim]
valid_values = value[valid_mask]
# Flatten cache and KV for scatter operation
# Cache is viewed as [total_slots, D] where D = num_kv_heads * head_dim