[WIP] need to change flashattention to debug.
@@ -1007,9 +1007,8 @@ class OffloadEngine:
         if not self._debug_mode or not self._debug_hooks:
             return
 
-        # GPU cache has no layer dimension
-        k = self.k_cache_gpu[slot_idx]
-        v = self.v_cache_gpu[slot_idx]
+        # Use get_kv_for_slot for consistency with attention.py
+        k, v = self.get_kv_for_slot(slot_idx)
 
         for hook in self._debug_hooks:
             try:
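The point of this refactor is that the debug path and attention.py can no longer disagree about cache layout, since both go through one accessor. A plausible shape for that accessor, assuming the GPU cache is indexed [slot, block_size, kv_heads, head_dim] with no layer dimension (a sketch, not the repo's actual implementation):

def get_kv_for_slot(self, slot_idx: int):
    # Single point of truth for slot indexing; both attention.py and the
    # debug hooks above read the cache through here.
    return self.k_cache_gpu[slot_idx], self.v_cache_gpu[slot_idx]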
@@ -426,6 +426,14 @@ class Attention(nn.Module):
         if not cpu_block_table:
             raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
 
+        # Calculate valid tokens in the last block
+        # prefill_len = total prefilled tokens (current decode token not yet in CPU)
+        block_size = kvcache_manager.block_size
+        prefill_len = len(seq) - 1  # Exclude current decode token
+        last_block_valid_tokens = prefill_len % block_size
+        if last_block_valid_tokens == 0 and prefill_len > 0:
+            last_block_valid_tokens = block_size  # Last block is full
+
         # Apply sparse policy if enabled
         if kvcache_manager.sparse_policy is not None:
             policy_ctx = PolicyContext(
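The partial-block arithmetic above is easy to sanity-check in isolation. A standalone check, assuming a hypothetical block_size of 16 (the real value comes from kvcache_manager):

block_size = 16
for prefill_len in (33, 32, 1, 0):
    last_block_valid_tokens = prefill_len % block_size
    if last_block_valid_tokens == 0 and prefill_len > 0:
        last_block_valid_tokens = block_size  # exact multiple: last block is full
    print(prefill_len, "->", last_block_valid_tokens)  # 33 -> 1, 32 -> 16, 1 -> 1, 0 -> 0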
@@ -480,6 +488,18 @@ class Attention(nn.Module):
             else:
                 k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(num_blocks_in_chunk)
 
+            # Handle partial last block: slice to only include valid tokens
+            # This is critical because the rest of the block contains stale data
+            is_last_chunk = (end == len(cpu_block_table))
+            if is_last_chunk and last_block_valid_tokens < block_size:
+                # Calculate total valid tokens in this chunk
+                # All blocks except the last are full, last block has last_block_valid_tokens
+                full_blocks = num_blocks_in_chunk - 1
+                valid_tokens = full_blocks * block_size + last_block_valid_tokens
+                # Slice KV: [batch, seqlen, heads, dim] -> [batch, valid_tokens, heads, dim]
+                k_chunk = k_chunk[:, :valid_tokens, :, :]
+                v_chunk = v_chunk[:, :valid_tokens, :, :]
+
             # Compute attention for this chunk
             o_chunk, lse_chunk = flash_attn_with_lse(
                 q_batched, k_chunk, v_chunk,
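Returning the log-sum-exp alongside the output is what makes chunked decode exact: per-chunk partials can be merged afterwards instead of attending over the whole sequence at once. A minimal sketch of that merge, assuming o has shape [batch, q_len, heads, head_dim] and lse has shape [batch, heads, q_len]; the shapes and the helper name are assumptions, not the repo's actual API:

import torch

def merge_attn_chunks(o1, lse1, o2, lse2):
    # o*:   [batch, q_len, heads, head_dim] partial attention outputs
    # lse*: [batch, heads, q_len] log of each chunk's softmax normalizer
    lse = torch.logaddexp(lse1, lse2)                         # combined normalizer
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # -> [b, q, h, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse

Because w1 + w2 sum to one by construction, folding chunks in one at a time yields the same result as a single softmax over all keys.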
@@ -518,6 +538,11 @@ class Attention(nn.Module):
             start_pos = context.decode_start_pos_in_block
             num_accumulated = pos_in_block - start_pos + 1
 
+            # IMPORTANT: Sync compute_stream with default stream before reading decode_slot
+            # store_kvcache writes to decode_slot on default stream (before entering this function)
+            # We need to ensure that write is complete before reading on compute_stream
+            compute_stream.wait_stream(torch.cuda.default_stream())
+
             with torch.cuda.stream(compute_stream):
                 if num_accumulated > 0:
                     # GPU cache has no layer dimension
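The hazard being guarded against is the classic CUDA ordering bug: two streams are unordered with respect to each other, so a kernel queued on compute_stream can start before a write issued on the default stream has landed. A self-contained illustration of the wait_stream pattern (tensor names here are illustrative, not the repo's):

import torch

if torch.cuda.is_available():
    compute_stream = torch.cuda.Stream()
    decode_slot = torch.zeros(1 << 20, device="cuda")
    decode_slot.fill_(1.0)  # stand-in for store_kvcache writing on the default stream
    # Without this wait, the sum below may read stale zeros.
    compute_stream.wait_stream(torch.cuda.default_stream())
    with torch.cuda.stream(compute_stream):
        total = decode_slot.sum()
    # Order the result back onto the default stream before consuming it.
    torch.cuda.current_stream().wait_stream(compute_stream)
    print(total.item())  # 1048576.0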
@@ -6,6 +6,7 @@ Injects distinctive K/V values, verifies loaded tensors match expected patterns.
 import os
 os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
 
+import inspect
 from random import randint, seed
 from typing import Dict, List
 import torch
@@ -30,6 +31,27 @@ def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor,
     if layer_id != 0:
         return
 
+    # Walk up the stack to the caller's `self` (the offload engine) and
+    # print k_cache_gpu[*][0,0,0] for all slots, bracketing the current one
+    frame = inspect.currentframe()
+    try:
+        caller_frame = frame.f_back
+        if caller_frame is not None:
+            local_vars = caller_frame.f_locals
+            if 'self' in local_vars:
+                self_obj = local_vars['self']
+                if hasattr(self_obj, 'k_cache_gpu'):
+                    num_slots = self_obj.k_cache_gpu.shape[0]
+                    vals = []
+                    for i in range(num_slots):
+                        v = self_obj.k_cache_gpu[i][0, 0, 0].item()
+                        if i == slot_idx:
+                            vals.append(f"[{v}]")
+                        else:
+                            vals.append(str(v))
+                    print(f"[DEBUG] k_cache_gpu[0..{num_slots-1}][0,0,0] = [{', '.join(vals)}]")
+    finally:
+        del frame
+
     load_log.append({
         "chunk_idx": current_chunk[0],
         "cpu_block_id": cpu_block_id,
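Wiring this hook into the engine only needs the two attributes visible in the OffloadEngine hunk above; whether the repo exposes a public registration method is not shown in this diff, so treat this as a sketch:

engine = ...  # the OffloadEngine under test
engine._debug_mode = True
engine._debug_hooks.append(debug_load_hook)
# Each subsequent load then prints the layer-0 k_cache_gpu snapshot
# and appends a record to load_log.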