nano-vllm/nanovllm/layers/attention.py
Zijie Tian fa7601f4b8 ♻️ refactor: remove cross-layer pipeline and rename compute_chunked_prefill
- Remove cross-layer pipeline from OffloadEngine (saves ~1GB GPU memory for long sequences)
  - Delete layer_k/v_buffer_a/b double buffers
  - Remove start_decode_pipeline, get_decode_layer_kv, end_decode_pipeline methods
  - Remove pipeline state tracking variables
- Simplify decode to use ring buffer pipeline only (more efficient for long sequences)
- Rename compute_chunked_attention → compute_chunked_prefill for clarity
- Add mandatory needle test requirements: --enable-offload --input-len 32768

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 02:10:40 +08:00

import logging

import torch
import torch.cuda.nvtx
from torch import nn
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache

from nanovllm.utils.context import get_context

logger = logging.getLogger(__name__)


def store_kvcache(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """
    Store key/value tensors into KV cache using slot mapping.

    This is a pure PyTorch implementation replacing the previous Triton kernel.
    Uses index_copy_ for efficient in-place scatter operation.

    Args:
        key: [N, num_kv_heads, head_dim]
        value: [N, num_kv_heads, head_dim]
        k_cache: [num_blocks, block_size, num_kv_heads, head_dim] or similar
        v_cache: same shape as k_cache
        slot_mapping: [N] with values as flat indices, -1 means skip
    """
    is_capturing = torch.cuda.is_current_stream_capturing()
    if is_capturing:
        # During CUDA graph capture, assume all slots are valid.
        # CUDA graphs don't support data-dependent operations like boolean indexing.
        # This is safe because decode (captured) always has valid slots.
        valid_slots = slot_mapping
        valid_keys = key
        valid_values = value
    else:
        # Normal execution: filter out invalid slots (slot == -1)
        valid_mask = slot_mapping >= 0
        if not valid_mask.any():
            return
        valid_slots = slot_mapping[valid_mask]
        valid_keys = key[valid_mask]  # [M, num_kv_heads, head_dim]
        valid_values = value[valid_mask]

    # Flatten cache and KV for scatter operation
    # Cache is viewed as [total_slots, D] where D = num_kv_heads * head_dim
    N, num_kv_heads, head_dim = key.shape
    D = num_kv_heads * head_dim
    total_slots = k_cache.numel() // D
    k_cache_flat = k_cache.view(total_slots, D)
    v_cache_flat = v_cache.view(total_slots, D)
    valid_keys_flat = valid_keys.reshape(-1, D)
    valid_values_flat = valid_values.reshape(-1, D)

    # In-place scatter using index_copy_.
    # Even if valid_slots is an empty tensor, index_copy_ is safe (it does not modify any data).
    k_cache_flat.index_copy_(0, valid_slots.long(), valid_keys_flat)
    v_cache_flat.index_copy_(0, valid_slots.long(), valid_values_flat)


class Attention(nn.Module):

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])
        # Layer ID set by model_runner after model creation
        self.layer_id: int = -1
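        # Typical wiring (illustrative only; the concrete values are hypothetical):
        #   attn = Attention(num_heads=32, head_dim=128, scale=128 ** -0.5, num_kv_heads=8)
        #   attn.layer_id = 0    # assigned by model_runner after model creation
        #   attn.k_cache / attn.v_cache are later rebound by the runner to the allocated cache tensors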

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache

        # Determine if we're in chunked offload mode
        is_chunked_offload = (
            context.is_chunked_prefill and
            hasattr(context, 'kvcache_manager') and
            context.kvcache_manager is not None and
            hasattr(context.kvcache_manager, 'offload_engine')
        )

        #! Ensure synchronization before accessing k_cache/v_cache
        # torch.cuda.synchronize()
        #! =======================================================

        if is_chunked_offload and context.is_prefill:
            # Chunked prefill mode: write KV to per-layer prefill buffer (not GPU slot)
            # This enables fully async offloads since each layer has its own buffer.
            offload_engine = context.kvcache_manager.offload_engine
            compute_stream = offload_engine.compute_stream
            # Wait for default stream to ensure slot_mapping tensor transfer is complete
            compute_stream.wait_stream(torch.cuda.default_stream())
            with torch.cuda.stream(compute_stream):
                # Write KV to per-layer prefill buffer (contiguous write, no slot_mapping)
                # k, v shape: [num_tokens, kv_heads, head_dim]
                num_tokens = k.shape[0]
                offload_engine.prefill_k_buffer[self.layer_id, :num_tokens].copy_(k)
                offload_engine.prefill_v_buffer[self.layer_id, :num_tokens].copy_(v)
        elif is_chunked_offload:
            # Chunked decode mode: use compute_stream for store_kvcache
            # This ensures proper synchronization with per-layer offload
            compute_stream = context.kvcache_manager.offload_engine.compute_stream
            if k_cache.numel() and v_cache.numel():
                # CRITICAL: Wait for default stream to ensure slot_mapping tensor transfer is complete
                # slot_mapping is created with non_blocking=True on default stream, but we use it
                # on compute_stream. Without this sync, index_copy_ can get corrupted indices.
                compute_stream.wait_stream(torch.cuda.default_stream())
                with torch.cuda.stream(compute_stream):
                    store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        else:
            # Normal mode: store on default stream
            if k_cache.numel() and v_cache.numel():
                store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)

        if context.is_prefill:
            if context.is_chunked_prefill:
                # Chunked prefill: merge attention from previous KV
                o = self._chunked_prefill_attention(q, k, v, context)
            elif context.block_tables is not None:  # prefix cache
                k, v = k_cache, v_cache
                o = flash_attn_varlen_func(
                    q, k, v,
                    max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                    max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                    softmax_scale=self.scale, causal=True, block_table=context.block_tables)
            else:
                o = flash_attn_varlen_func(
                    q, k, v,
                    max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                    max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                    softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:  # decode
            if context.is_chunked_prefill:
                # Chunked decode: need to load all KV from CPU+GPU
                # Store current decode token to per-layer decode buffer
                # This is needed because GPU cache has no layer dimension,
                # so all layers would overwrite each other in decode_slot.
                kvcache_manager = context.kvcache_manager
                offload_engine = kvcache_manager.offload_engine
                pos_in_block = context.decode_pos_in_block
                # k, v shape: [1, kv_heads, head_dim]
                offload_engine.decode_k_buffer[self.layer_id, pos_in_block].copy_(k.squeeze(0))
                offload_engine.decode_v_buffer[self.layer_id, pos_in_block].copy_(v.squeeze(0))
                o = self._chunked_decode_attention(q, k, v, context)
            else:
                o = flash_attn_with_kvcache(
                    q.unsqueeze(1), k_cache, v_cache,
                    cache_seqlens=context.context_lens, block_table=context.block_tables,
                    softmax_scale=self.scale, causal=True)
        return o

    def _chunked_prefill_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute attention with per-layer prefill buffer for async offload.

        Simplified design:
        - All computation logic is delegated to sparse_policy.compute_chunked_prefill()
        - This method only handles async offload after computation

        The policy handles:
        1. Loading historical blocks from CPU
        2. Computing attention against historical KV (no causal mask)
        3. Computing attention against current KV from prefill buffer (causal)
        4. Merging all results
        """
        current_chunk_idx = context.current_chunk_idx
        torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")

        num_tokens = k.shape[0]
        kvcache_manager = context.kvcache_manager
        seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
        offload_engine = kvcache_manager.offload_engine if kvcache_manager is not None else None

        # Get sparse policy - required for chunked prefill
        sparse_policy = kvcache_manager.sparse_policy
        if sparse_policy is None:
            raise RuntimeError("sparse_policy is required for chunked prefill")

        # [DEBUG] Verify execution path
        logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
                     f"policy={sparse_policy}, layer={self.layer_id}, chunk={current_chunk_idx}")

        # Delegate all computation to policy (no flash_attn or merge calls here!)
        final_o = sparse_policy.compute_chunked_prefill(
            q, k, v,
            self.layer_id,
            self.scale,
            offload_engine,
            kvcache_manager,
            current_chunk_idx,
            seq,
            num_tokens,
        )

        torch.cuda.nvtx.range_pop()  # ChunkedPrefill

        # Per-layer ASYNC offload: offload prefill buffer to CPU
        # No waiting required! Each layer has its own buffer and stream.
        if offload_engine is not None and seq is not None:
            cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
            if current_chunk_idx < len(cpu_block_ids):
                cpu_block_id = cpu_block_ids[current_chunk_idx]
                # Async offload - no waiting, fully parallel across layers
                offload_engine.offload_prefill_buffer_async(
                    self.layer_id, cpu_block_id, num_tokens
                )
        return final_o

    def _chunked_decode_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute decode attention by delegating to sparse policy.

        Simplified design:
        - All computation logic is delegated to sparse_policy.compute_chunked_decode()
        - This method only validates the policy and delegates

        The policy handles:
        1. Loading prefilled blocks from CPU via pipeline
        2. Computing attention against prefilled KV
        3. Reading accumulated decode tokens from decode buffer
        4. Merging all results
        """
        kvcache_manager = context.kvcache_manager
        seq = context.chunked_seq
        offload_engine = kvcache_manager.offload_engine

        # Get sparse policy - required for chunked decode
        sparse_policy = kvcache_manager.sparse_policy
        if sparse_policy is None:
            raise RuntimeError("sparse_policy is required for chunked decode")

        # Check if policy supports decode phase
        if not sparse_policy.supports_decode:
            raise RuntimeError(f"{sparse_policy} does not support decode phase")

        # [DEBUG] Verify execution path
        logger.debug(f"[DEBUG] Calling sparse_policy.compute_chunked_decode, "
                     f"policy={sparse_policy}, layer={self.layer_id}")

        # Delegate all computation to policy (no flash_attn or merge calls here!)
        return sparse_policy.compute_chunked_decode(
            q,
            self.layer_id,
            self.scale,
            offload_engine,
            kvcache_manager,
            seq,
        )