[WIP] need to change flashattention to debug.
@@ -1007,9 +1007,8 @@ class OffloadEngine:
         if not self._debug_mode or not self._debug_hooks:
             return

-        # GPU cache has no layer dimension
-        k = self.k_cache_gpu[slot_idx]
-        v = self.v_cache_gpu[slot_idx]
+        # Use get_kv_for_slot for consistency with attention.py
+        k, v = self.get_kv_for_slot(slot_idx)

         for hook in self._debug_hooks:
             try:
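Note on the hunk above: the debug path now reads the cache through get_kv_for_slot instead of indexing k_cache_gpu/v_cache_gpu directly. Below is a hypothetical sketch of what that accessor is assumed to look like, inferred only from the two removed lines; the real OffloadEngine method may do more (dtype, layout, or stream handling).

    # Hypothetical sketch (assumption, not the repo's code): get_kv_for_slot is
    # taken to wrap the same per-slot indexing the removed lines did inline.
    import torch

    class OffloadEngineSketch:
        def __init__(self, num_slots, block_size, num_heads, head_dim):
            # GPU cache has no layer dimension (per the removed comment); CPU
            # tensors are used here so the sketch runs anywhere.
            self.k_cache_gpu = torch.zeros(num_slots, block_size, num_heads, head_dim)
            self.v_cache_gpu = torch.zeros(num_slots, block_size, num_heads, head_dim)

        def get_kv_for_slot(self, slot_idx):
            # One accessor shared by debug hooks and attention.py, which is the
            # consistency motivation stated in the added comment.
            return self.k_cache_gpu[slot_idx], self.v_cache_gpu[slot_idx]

    engine = OffloadEngineSketch(num_slots=4, block_size=16, num_heads=8, head_dim=64)
    k, v = engine.get_kv_for_slot(0)
    assert k.shape == (16, 8, 64)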
@@ -426,6 +426,14 @@ class Attention(nn.Module):
         if not cpu_block_table:
             raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")

+        # Calculate valid tokens in the last block
+        # prefill_len = total prefilled tokens (current decode token not yet in CPU)
+        block_size = kvcache_manager.block_size
+        prefill_len = len(seq) - 1  # Exclude current decode token
+        last_block_valid_tokens = prefill_len % block_size
+        if last_block_valid_tokens == 0 and prefill_len > 0:
+            last_block_valid_tokens = block_size  # Last block is full
+
         # Apply sparse policy if enabled
         if kvcache_manager.sparse_policy is not None:
             policy_ctx = PolicyContext(
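The added bookkeeping above decides how many tokens of the last CPU block are real. A standalone check of that arithmetic, with illustrative block_size/prefill_len values (not taken from the repo):

    # Standalone check of the last-block arithmetic (illustrative values only).
    def last_block_valid(prefill_len: int, block_size: int) -> int:
        valid = prefill_len % block_size
        if valid == 0 and prefill_len > 0:
            valid = block_size  # last block is exactly full, not empty
        return valid

    assert last_block_valid(35, 16) == 3    # 2 full blocks + 3 tokens in the last one
    assert last_block_valid(32, 16) == 16   # boundary case: last block is full
    assert last_block_valid(0, 16) == 0     # nothing prefilled yet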
@@ -480,6 +488,18 @@ class Attention(nn.Module):
             else:
                 k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(num_blocks_in_chunk)

+            # Handle partial last block: slice to only include valid tokens
+            # This is critical because the rest of the block contains stale data
+            is_last_chunk = (end == len(cpu_block_table))
+            if is_last_chunk and last_block_valid_tokens < block_size:
+                # Calculate total valid tokens in this chunk
+                # All blocks except the last are full, last block has last_block_valid_tokens
+                full_blocks = num_blocks_in_chunk - 1
+                valid_tokens = full_blocks * block_size + last_block_valid_tokens
+                # Slice KV: [batch, seqlen, heads, dim] -> [batch, valid_tokens, heads, dim]
+                k_chunk = k_chunk[:, :valid_tokens, :, :]
+                v_chunk = v_chunk[:, :valid_tokens, :, :]
+
             # Compute attention for this chunk
             o_chunk, lse_chunk = flash_attn_with_lse(
                 q_batched, k_chunk, v_chunk,
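The slicing added above drops the stale tail of the last block before attention is computed over the chunk. A shape-only sketch with dummy tensors and made-up sizes, following the [batch, seqlen, heads, dim] layout named in the comment:

    # Shape-only sketch of the partial-last-chunk slicing (dummy sizes).
    import torch

    block_size = 16
    num_blocks_in_chunk = 4
    last_block_valid_tokens = 5            # e.g. prefill_len = 53 -> 53 % 16 = 5

    # [batch, seqlen, heads, dim]; the tail of the last block is stale.
    k_chunk = torch.randn(1, num_blocks_in_chunk * block_size, 8, 64)
    v_chunk = torch.randn_like(k_chunk)

    full_blocks = num_blocks_in_chunk - 1
    valid_tokens = full_blocks * block_size + last_block_valid_tokens   # 3*16 + 5 = 53
    k_chunk = k_chunk[:, :valid_tokens, :, :]
    v_chunk = v_chunk[:, :valid_tokens, :, :]
    assert k_chunk.shape[1] == valid_tokens == 53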
@@ -518,6 +538,11 @@ class Attention(nn.Module):
         start_pos = context.decode_start_pos_in_block
         num_accumulated = pos_in_block - start_pos + 1

+        # IMPORTANT: Sync compute_stream with default stream before reading decode_slot
+        # store_kvcache writes to decode_slot on default stream (before entering this function)
+        # We need to ensure that write is complete before reading on compute_stream
+        compute_stream.wait_stream(torch.cuda.default_stream())
+
         with torch.cuda.stream(compute_stream):
             if num_accumulated > 0:
                 # GPU cache has no layer dimension
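The wait_stream call added above orders the default-stream write against the later read on compute_stream. A minimal, self-contained sketch of that producer/consumer ordering; the buffer and stream names below are stand-ins, not the engine's real decode_slot or streams:

    # Producer/consumer sketch of the stream ordering (stand-in buffer and streams).
    import torch

    if torch.cuda.is_available():
        compute_stream = torch.cuda.Stream()
        decode_slot = torch.zeros(16, 8, 64, device="cuda")   # stand-in for the real slot

        # Producer: write on the default stream (analogous to store_kvcache).
        decode_slot.fill_(1.0)

        # Consumer: block compute_stream behind everything already queued on the
        # default stream, then read on compute_stream.
        compute_stream.wait_stream(torch.cuda.default_stream())
        with torch.cuda.stream(compute_stream):
            k = decode_slot * 2.0

        # Join back before checking the result on the default stream.
        torch.cuda.current_stream().wait_stream(compute_stream)
        assert torch.all(k == 2.0)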