diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py
index 11f4516..de966b7 100644
--- a/nanovllm/layers/attention.py
+++ b/nanovllm/layers/attention.py
@@ -30,18 +30,23 @@ def store_kvcache(
         v_cache: same shape as k_cache
         slot_mapping: [N] with values as flat indices, -1 means skip
     """
-    # Filter out invalid slots (slot == -1)
-    valid_mask = slot_mapping >= 0
     is_capturing = torch.cuda.is_current_stream_capturing()
-    if not is_capturing:
+    if is_capturing:
+        # During CUDA graph capture, assume all slots are valid.
+        # CUDA graphs don't support data-dependent operations like boolean indexing.
+        # This is safe because decode (captured) always has valid slots.
+        valid_slots = slot_mapping
+        valid_keys = key
+        valid_values = value
+    else:
+        # Normal execution: filter out invalid slots (slot == -1)
+        valid_mask = slot_mapping >= 0
         if not valid_mask.any():
             return
-
-    valid_slots = slot_mapping[valid_mask]
-    valid_keys = key[valid_mask]  # [M, num_kv_heads, head_dim]
-    valid_values = value[valid_mask]
+        valid_slots = slot_mapping[valid_mask]
+        valid_keys = key[valid_mask]  # [M, num_kv_heads, head_dim]
+        valid_values = value[valid_mask]
 
     # Flatten cache and KV for scatter operation
     # Cache is viewed as [total_slots, D] where D = num_kv_heads * head_dim
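
For reference, the code below this hunk scatters the selected rows into the flattened cache (per the trailing comments, the cache is viewed as `[total_slots, D]` with `D = num_kv_heads * head_dim`). The snippet below is only a minimal sketch of that step under those assumed shapes; the helper name `scatter_into_cache` and the choice of `index_copy_` are illustrative, not necessarily what `store_kvcache` actually does.

```python
import torch

def scatter_into_cache(k_cache, v_cache, valid_slots, valid_keys, valid_values):
    """Illustrative sketch only: write selected K/V rows into flat cache slots.

    Assumes k_cache/v_cache are contiguous with trailing dims
    (num_kv_heads, head_dim), matching the [M, num_kv_heads, head_dim] keys/values.
    """
    num_kv_heads, head_dim = k_cache.shape[-2], k_cache.shape[-1]
    D = num_kv_heads * head_dim

    # View the caches as [total_slots, D]; views share storage, so the
    # in-place index_copy_ below updates the real cache tensors.
    k_flat = k_cache.view(-1, D)
    v_flat = v_cache.view(-1, D)

    # index_copy_ expects a 1-D int64 index.
    idx = valid_slots.to(torch.long)

    # Fixed shapes, no data-dependent control flow: safe under CUDA graph capture.
    k_flat.index_copy_(0, idx, valid_keys.reshape(-1, D))
    v_flat.index_copy_(0, idx, valid_values.reshape(-1, D))
```

Unlike boolean-mask indexing, this scatter involves no data-dependent shapes, which is why the capture branch above can pass the unfiltered `slot_mapping`, `key`, and `value` straight through.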