[WIP] Need to change flash attention to debug.

Zijie Tian
2026-01-01 00:58:22 +08:00
parent 30462fe89a
commit 965c8aff12
3 changed files with 49 additions and 3 deletions

View File

@@ -1007,9 +1007,8 @@ class OffloadEngine:
         if not self._debug_mode or not self._debug_hooks:
             return
-        # GPU cache has no layer dimension
-        k = self.k_cache_gpu[slot_idx]
-        v = self.v_cache_gpu[slot_idx]
+        # Use get_kv_for_slot for consistency with attention.py
+        k, v = self.get_kv_for_slot(slot_idx)
         for hook in self._debug_hooks:
             try:

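Note: a minimal sketch of how the hook dispatch could look after this change, assuming get_kv_for_slot simply wraps the per-slot indexing removed above, and with the dispatch method named _run_debug_hooks here purely for illustration (neither detail is confirmed by this diff).

    def _run_debug_hooks(self, slot_idx: int, layer_id: int, cpu_block_id: int):
        if not self._debug_mode or not self._debug_hooks:
            return
        # Read the slot's K/V through the shared helper instead of indexing
        # k_cache_gpu / v_cache_gpu directly, so debug hooks and attention.py
        # observe the cache through the same code path.
        k, v = self.get_kv_for_slot(slot_idx)
        for hook in self._debug_hooks:
            try:
                hook(slot_idx, layer_id, cpu_block_id, k, v)
            except Exception as exc:
                # A misbehaving debug hook must not take down the engine.
                print(f"[OffloadEngine] debug hook failed: {exc}")
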
View File

@@ -426,6 +426,14 @@ class Attention(nn.Module):
         if not cpu_block_table:
             raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
+        # Calculate valid tokens in the last block
+        # prefill_len = total prefilled tokens (current decode token not yet in CPU)
+        block_size = kvcache_manager.block_size
+        prefill_len = len(seq) - 1  # Exclude current decode token
+        last_block_valid_tokens = prefill_len % block_size
+        if last_block_valid_tokens == 0 and prefill_len > 0:
+            last_block_valid_tokens = block_size  # Last block is full
         # Apply sparse policy if enabled
         if kvcache_manager.sparse_policy is not None:
             policy_ctx = PolicyContext(
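
Note: the last-block arithmetic added above reduces to prefill_len % block_size, with a zero remainder mapped back to a full block. A standalone check with made-up numbers (not from the repo):

    def last_block_valid_tokens(prefill_len: int, block_size: int) -> int:
        # Tokens occupying the final, possibly partial, CPU block.
        valid = prefill_len % block_size
        if valid == 0 and prefill_len > 0:
            valid = block_size  # the last block is exactly full
        return valid

    assert last_block_valid_tokens(37, 16) == 5    # 2 full blocks + 5 tokens
    assert last_block_valid_tokens(32, 16) == 16   # remainder 0 -> full last block
    assert last_block_valid_tokens(0, 16) == 0     # nothing prefilled yet
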
@@ -480,6 +488,18 @@ class Attention(nn.Module):
             else:
                 k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(num_blocks_in_chunk)
+            # Handle partial last block: slice to only include valid tokens
+            # This is critical because the rest of the block contains stale data
+            is_last_chunk = (end == len(cpu_block_table))
+            if is_last_chunk and last_block_valid_tokens < block_size:
+                # Calculate total valid tokens in this chunk
+                # All blocks except the last are full, last block has last_block_valid_tokens
+                full_blocks = num_blocks_in_chunk - 1
+                valid_tokens = full_blocks * block_size + last_block_valid_tokens
+                # Slice KV: [batch, seqlen, heads, dim] -> [batch, valid_tokens, heads, dim]
+                k_chunk = k_chunk[:, :valid_tokens, :, :]
+                v_chunk = v_chunk[:, :valid_tokens, :, :]
             # Compute attention for this chunk
             o_chunk, lse_chunk = flash_attn_with_lse(
                 q_batched, k_chunk, v_chunk,
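
Note: flash_attn_with_lse returns a log-sum-exp alongside each chunk's output; the per-chunk results are presumably combined with the standard LSE merge, which lies outside this hunk. A minimal sketch of that merge, assuming o is [batch, seq_q, heads, dim] and lse is [batch, heads, seq_q] (the usual flash-attention convention):

    import torch

    def merge_chunk_outputs(o1, lse1, o2, lse2):
        # Combined normalizer for the two chunks.
        lse = torch.logaddexp(lse1, lse2)
        # Reweight each partial output by its share of the total softmax mass.
        w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)  # [batch, seq_q, heads, 1]
        w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
        return o1 * w1 + o2 * w2, lse
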
@@ -518,6 +538,11 @@ class Attention(nn.Module):
         start_pos = context.decode_start_pos_in_block
         num_accumulated = pos_in_block - start_pos + 1
+        # IMPORTANT: Sync compute_stream with default stream before reading decode_slot
+        # store_kvcache writes to decode_slot on default stream (before entering this function)
+        # We need to ensure that write is complete before reading on compute_stream
+        compute_stream.wait_stream(torch.cuda.default_stream())
         with torch.cuda.stream(compute_stream):
             if num_accumulated > 0:
                 # GPU cache has no layer dimension
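
Note: the ordering problem fixed above is plain CUDA stream semantics; a self-contained toy example of the same pattern (names here are illustrative, not from the repo):

    import torch

    x = torch.zeros(1024, device="cuda")
    compute_stream = torch.cuda.Stream()

    x.fill_(1.0)  # producer write, enqueued on the default stream (like store_kvcache)
    compute_stream.wait_stream(torch.cuda.default_stream())  # order the side stream after it
    with torch.cuda.stream(compute_stream):
        y = x * 2  # safe: the fill_ above is guaranteed to complete first
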

View File

@@ -6,6 +6,7 @@ Injects distinctive K/V values, verifies loaded tensors match expected patterns.
 import os
 os.environ["NANOVLLM_LOG_LEVEL"] = "WARNING"
+import inspect
 from random import randint, seed
 from typing import Dict, List
 import torch
@@ -30,6 +31,27 @@ def debug_load_hook(slot_idx: int, layer_id: int, cpu_block_id: int, k: Tensor,
     if layer_id != 0:
         return
+    # Go up the stack to find kvcache_manager and print k_cache_gpu[*][0,0,0] for all slots
+    frame = inspect.currentframe()
+    try:
+        caller_frame = frame.f_back
+        if caller_frame is not None:
+            local_vars = caller_frame.f_locals
+            if 'self' in local_vars:
+                self_obj = local_vars['self']
+                if hasattr(self_obj, 'k_cache_gpu'):
+                    num_slots = self_obj.k_cache_gpu.shape[0]
+                    vals = []
+                    for i in range(num_slots):
+                        v = self_obj.k_cache_gpu[i][0,0,0].item()
+                        if i == slot_idx:
+                            vals.append(f"[{v}]")
+                        else:
+                            vals.append(str(v))
+                    print(f"[DEBUG] k_cache_gpu[0..{num_slots-1}][0,0,0] = [{', '.join(vals)}]")
+    finally:
+        del frame
     load_log.append({
         "chunk_idx": current_chunk[0],
         "cpu_block_id": cpu_block_id,