[WIP] Before fix bench_offload.py.
This commit is contained in:
@@ -32,8 +32,12 @@ def store_kvcache(
|
||||
"""
|
||||
# Filter out invalid slots (slot == -1)
|
||||
valid_mask = slot_mapping >= 0
|
||||
if not valid_mask.any():
|
||||
return
|
||||
|
||||
is_capturing = torch.cuda.is_current_stream_capturing()
|
||||
|
||||
if not is_capturing:
|
||||
if not valid_mask.any():
|
||||
return
|
||||
|
||||
valid_slots = slot_mapping[valid_mask]
|
||||
valid_keys = key[valid_mask] # [M, num_kv_heads, head_dim]
|
||||
@@ -51,6 +55,7 @@ def store_kvcache(
|
||||
valid_values_flat = valid_values.reshape(-1, D)
|
||||
|
||||
# In-place scatter using index_copy_
|
||||
# 即使 valid_slots 为空张量,index_copy_ 也是安全的(不会修改数据)。
|
||||
k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat)
|
||||
v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat)
|
||||
|
||||
@@ -86,7 +91,7 @@ class Attention(nn.Module):
|
||||
)
|
||||
|
||||
#! Ensure synchronization before accessing k_cache/v_cache
|
||||
torch.cuda.synchronize()
|
||||
# torch.cuda.synchronize()
|
||||
#! =======================================================
|
||||
|
||||
if is_chunked_offload:
|
||||
|
||||
Reference in New Issue
Block a user