[fix] Fixed bench_offload.py, BUT performance DEGRADED.
@@ -30,15 +30,20 @@ def store_kvcache(
         v_cache: same shape as k_cache
         slot_mapping: [N] with values as flat indices, -1 means skip
     """
-    # Filter out invalid slots (slot == -1)
-    valid_mask = slot_mapping >= 0
-
     is_capturing = torch.cuda.is_current_stream_capturing()
-
-    if not is_capturing:
+    if is_capturing:
+        # During CUDA graph capture, assume all slots are valid.
+        # CUDA graphs don't support data-dependent operations like boolean indexing.
+        # This is safe because decode (captured) always has valid slots.
+        valid_slots = slot_mapping
+        valid_keys = key
+        valid_values = value
+    else:
+        # Normal execution: filter out invalid slots (slot == -1)
+        valid_mask = slot_mapping >= 0
         if not valid_mask.any():
             return
 
         valid_slots = slot_mapping[valid_mask]
         valid_keys = key[valid_mask]  # [M, num_kv_heads, head_dim]
         valid_values = value[valid_mask]
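For context, here is a minimal end-to-end sketch of the capture-aware store_kvcache after this change. The function signature details, the flat [num_slots, num_kv_heads, head_dim] cache layout, and the final index_copy_ scatter are illustrative assumptions, since the actual cache-write code lies outside the hunk shown above.

import torch


def store_kvcache(
    key: torch.Tensor,           # [N, num_kv_heads, head_dim]
    value: torch.Tensor,         # [N, num_kv_heads, head_dim]
    k_cache: torch.Tensor,       # assumed flat layout: [num_slots, num_kv_heads, head_dim]
    v_cache: torch.Tensor,       # same shape as k_cache
    slot_mapping: torch.Tensor,  # [N] int64, flat slot indices, -1 means skip
) -> None:
    is_capturing = torch.cuda.is_current_stream_capturing()

    if is_capturing:
        # Graph capture cannot record data-dependent ops such as boolean
        # indexing; the captured decode path always carries valid slots.
        valid_slots = slot_mapping
        valid_keys = key
        valid_values = value
    else:
        # Eager path: drop entries padded with slot == -1.
        valid_mask = slot_mapping >= 0
        if not valid_mask.any():
            return
        valid_slots = slot_mapping[valid_mask]
        valid_keys = key[valid_mask]      # [M, num_kv_heads, head_dim]
        valid_values = value[valid_mask]

    # Assumed cache write (not part of the hunk): scatter the selected rows
    # into the flat caches at their slot indices.
    k_cache.index_copy_(0, valid_slots, valid_keys)
    v_cache.index_copy_(0, valid_slots, valid_values)

Note that evaluating valid_mask.any() in a Python conditional, as well as the boolean indexing itself, forces a device-to-host synchronization on every eager call; that extra sync is one plausible contributor to the slowdown flagged in the commit title.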