Files
nano-vllm/nanovllm/layers/attention.py
2026-01-01 00:58:22 +08:00

573 lines
26 KiB
Python

import logging
import torch
import torch.cuda.nvtx
from torch import nn
import triton
import triton.language as tl
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context
from nanovllm.kvcache.sparse.policy import PolicyContext
logger = logging.getLogger(__name__)
@triton.jit
def store_kvcache_kernel(
key_ptr,
key_stride,
value_ptr,
value_stride,
k_cache_ptr,
v_cache_ptr,
slot_mapping_ptr,
D: tl.constexpr,
):
idx = tl.program_id(0)
slot = tl.load(slot_mapping_ptr + idx)
if slot == -1: return
key_offsets = idx * key_stride + tl.arange(0, D)
value_offsets = idx * value_stride + tl.arange(0, D)
key = tl.load(key_ptr + key_offsets)
value = tl.load(value_ptr + value_offsets)
cache_offsets = slot * D + tl.arange(0, D)
tl.store(k_cache_ptr + cache_offsets, key)
tl.store(v_cache_ptr + cache_offsets, value)
def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
N, num_heads, head_dim = key.shape
D = num_heads * head_dim
assert key.stride(-1) == 1 and value.stride(-1) == 1
assert key.stride(1) == head_dim and value.stride(1) == head_dim
assert k_cache.stride(1) == D and v_cache.stride(1) == D
assert slot_mapping.numel() == N
store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
class Attention(nn.Module):
def __init__(
self,
num_heads,
head_dim,
scale,
num_kv_heads,
):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
self.scale = scale
self.num_kv_heads = num_kv_heads
self.k_cache = self.v_cache = torch.tensor([])
# Layer ID set by model_runner after model creation
self.layer_id: int = -1
def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
context = get_context()
k_cache, v_cache = self.k_cache, self.v_cache
if k_cache.numel() and v_cache.numel():
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
if context.is_prefill:
if context.is_chunked_prefill:
# Chunked prefill: merge attention from previous KV
o = self._chunked_prefill_attention(q, k, v, context)
elif context.block_tables is not None: # prefix cache
k, v = k_cache, v_cache
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else:
o = flash_attn_varlen_func(q, k, v,
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
else: # decode
if context.is_chunked_prefill:
# Chunked decode: need to load all KV from CPU+GPU
o = self._chunked_decode_attention(q, k, v, context)
else:
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
cache_seqlens=context.context_lens, block_table=context.block_tables,
softmax_scale=self.scale, causal=True)
return o
def _chunked_prefill_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
context,
) -> torch.Tensor:
"""
Compute attention with unified ring buffer for chunked prefill.
Ring buffer design:
- Current chunk's KV is written to ring_slot[chunk_idx % N]
- Previous chunks' KV are loaded from CPU using N-1 available slots
- Pipeline: pre-fill slots, then process with overlapped load/compute
For each layer:
1. Current chunk's KV is in k_batched, v_batched (just written by model)
2. Load previous chunks from CPU using available slots (pipeline)
3. Compute attention against previous KV (no causal mask)
4. Compute attention against current KV (causal)
5. Merge all results using online softmax
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
current_chunk_idx = context.current_chunk_idx
torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
# q, k, v shape: [total_tokens, num_heads, head_dim]
# Reshape for flash attention: [batch, seq, heads, dim]
q_batched = q.unsqueeze(0) # [1, total_tokens, heads, dim]
k_batched = k.unsqueeze(0)
v_batched = v.unsqueeze(0)
o_acc = None
lse_acc = None
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
# Get prefilled CPU blocks (blocks from previous chunks)
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
# Apply sparse policy if enabled
if cpu_block_table and kvcache_manager.sparse_policy is not None:
num_chunks = getattr(context, 'num_chunks', current_chunk_idx + 1)
policy_ctx = PolicyContext(
query_chunk_idx=current_chunk_idx,
num_query_chunks=num_chunks,
layer_id=self.layer_id,
query=None, # Prefill typically doesn't use query for selection
is_prefill=True,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
cpu_block_table, policy_ctx
)
if cpu_block_table:
offload_engine = kvcache_manager.offload_engine
# Get write slot for current chunk and available load slots
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
load_slots = offload_engine.get_load_slots_for_prefill(write_slot)
pipeline_depth = len(load_slots)
if pipeline_depth == 0:
# Only 1 slot total, cannot pipeline - use sync loading
o_acc, lse_acc = self._sync_load_previous_chunks(
q_batched, cpu_block_table, offload_engine
)
else:
# Use ring buffer pipeline
o_acc, lse_acc = self._ring_buffer_pipeline_load(
q_batched, cpu_block_table, load_slots, offload_engine,
current_chunk_idx
)
# Compute attention against current chunk's KV (with causal mask)
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} CurrentChunk (causal)")
current_o, current_lse = flash_attn_with_lse(
q_batched,
k_batched,
v_batched,
softmax_scale=self.scale,
causal=True,
)
torch.cuda.nvtx.range_pop()
# Merge with accumulated
if o_acc is None:
final_o = current_o
else:
# IMPORTANT: o_acc was computed on compute_stream. We need to sync before
# reading it on the default stream for the merge operation.
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
offload_engine = kvcache_manager.offload_engine
torch.cuda.default_stream().wait_stream(offload_engine.compute_stream)
torch.cuda.nvtx.range_push(f"MergeAttn: L{self.layer_id}")
final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)
torch.cuda.nvtx.range_pop()
torch.cuda.nvtx.range_pop() # ChunkedPrefill
# Per-layer offload: In new GPU cache architecture (no layer dimension),
# each layer must offload its KV to CPU before next layer overwrites the GPU slot.
if kvcache_manager is not None and hasattr(kvcache_manager, 'offload_engine'):
offload_engine = kvcache_manager.offload_engine
write_slot = offload_engine.get_write_slot_for_prefill(current_chunk_idx)
seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
if seq is not None:
cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
if current_chunk_idx < len(cpu_block_ids):
cpu_block_id = cpu_block_ids[current_chunk_idx]
offload_engine.offload_slot_layer_to_cpu(write_slot, self.layer_id, cpu_block_id)
# Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
return final_o.squeeze(0)
def _sync_load_previous_chunks(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
offload_engine,
):
"""Synchronous loading fallback when pipeline_depth=0."""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
o_acc, lse_acc = None, None
compute_stream = offload_engine.compute_stream
for block_idx, cpu_block_id in enumerate(cpu_block_table):
# Load to slot 0 (single slot)
offload_engine.load_to_slot_layer(0, self.layer_id, cpu_block_id)
offload_engine.wait_slot_layer(0)
# IMPORTANT: Must use compute_stream to match wait_slot_layer
with torch.cuda.stream(compute_stream):
prev_k, prev_v = offload_engine.get_kv_for_slot(0)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
def _ring_buffer_pipeline_load(
self,
q_batched: torch.Tensor,
cpu_block_table: list,
load_slots: list,
offload_engine,
current_chunk_idx: int = -1,
):
"""
Ring buffer async pipeline loading with double buffering.
Uses compute_done events to ensure safe buffer reuse:
- Before loading to slot X, wait for previous compute on slot X to finish
- Before computing on slot X, wait for load to slot X to finish
Timeline with 2 slots (A, B):
┌──────────────┐
│ Load B0→A │
└──────────────┘
┌──────────────┐ ┌──────────────┐
│ Load B1→B │ │ Load B2→A │ ...
└──────────────┘ └──────────────┘
↘ ↘
┌──────────────┐ ┌──────────────┐
│ Compute(A) │ │ Compute(B) │ ...
└──────────────┘ └──────────────┘
The load_to_slot_layer internally waits for compute_done[slot] before
starting the transfer, ensuring no data race.
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
num_blocks = len(cpu_block_table)
if num_blocks == 0:
return None, None
pipeline_depth = len(load_slots)
if pipeline_depth == 0:
return None, None
o_acc, lse_acc = None, None
if pipeline_depth == 1:
# Only 1 slot available, cannot pipeline - use synchronous mode
# IMPORTANT: Must use compute_stream to match synchronization in
# load_to_slot_layer (waits for compute_done) and wait_slot_layer
slot = load_slots[0]
compute_stream = offload_engine.compute_stream
for block_idx in range(num_blocks):
cpu_block_id = cpu_block_table[block_idx]
offload_engine.load_to_slot_layer(slot, self.layer_id, cpu_block_id)
offload_engine.wait_slot_layer(slot)
with torch.cuda.stream(compute_stream):
# Debug: call hooks on compute_stream (synchronized with transfer)
if offload_engine.debug_mode:
offload_engine._call_debug_hooks(slot, self.layer_id, cpu_block_id)
prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
# Record compute done so next load can safely reuse this slot
offload_engine.record_slot_compute_done(slot)
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
return o_acc, lse_acc
# N-way pipeline: use ALL available slots for maximum overlap
# Pipeline depth = num_slots - 1 (num_slots blocks in flight)
num_slots = len(load_slots)
# Phase 1: Pre-load up to num_slots blocks to fill the pipeline
# This starts all transfers in parallel, utilizing full PCIe bandwidth
num_preload = min(num_slots, num_blocks)
for i in range(num_preload):
offload_engine.load_to_slot_layer(load_slots[i], self.layer_id, cpu_block_table[i])
# Phase 2: Main loop - compute and immediately reuse slot for next transfer
# Use dedicated compute_stream (not default stream) to enable overlap with transfers
compute_stream = offload_engine.compute_stream
for block_idx in range(num_blocks):
torch.cuda.nvtx.range_push(f"PipelineBlock: L{self.layer_id} B{block_idx}")
# Cycle through slots: slot[block_idx % num_slots]
current_slot = load_slots[block_idx % num_slots]
cpu_block_id = cpu_block_table[block_idx]
# Wait for current slot's transfer to complete (on compute_stream)
offload_engine.wait_slot_layer(current_slot)
# Compute attention on current slot's data
# IMPORTANT: Use dedicated compute_stream to avoid implicit sync with default stream
with torch.cuda.stream(compute_stream):
# Debug: call hooks on compute_stream (synchronized with transfer)
if offload_engine.debug_mode:
offload_engine._call_debug_hooks(current_slot, self.layer_id, cpu_block_id)
torch.cuda.nvtx.range_push(f"FlashAttn: L{self.layer_id} PrevBlock{block_idx}")
prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
prev_o, prev_lse = flash_attn_with_lse(
q_batched, prev_k, prev_v,
softmax_scale=self.scale,
causal=False,
)
torch.cuda.nvtx.range_pop()
# Record compute done - this allows the next transfer to safely overwrite this slot
offload_engine.record_slot_compute_done(current_slot)
# Immediately start loading the NEXT block into this slot (if more blocks remain)
# Key insight: reuse current_slot immediately after compute is done!
next_block_idx = block_idx + num_slots
if next_block_idx < num_blocks:
offload_engine.load_to_slot_layer(current_slot, self.layer_id, cpu_block_table[next_block_idx])
# Merge with accumulated (also on compute_stream for consistency)
with torch.cuda.stream(compute_stream):
if o_acc is None:
o_acc, lse_acc = prev_o, prev_lse
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
torch.cuda.nvtx.range_pop() # PipelineBlock
return o_acc, lse_acc
def _chunked_decode_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
context,
) -> torch.Tensor:
"""
Compute decode attention with double-buffering using decode_load_slots.
Decode uses:
- decode_slot (slot[0]): writes new token's KV
- decode_load_slots (slots[1:]): load previous chunks from CPU
Pipeline design:
- First half of decode_load_slots: 'compute' buffer
- Second half: 'prefetch' buffer
- Double-buffer between them for async overlap
Timeline:
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│Load C0→buf0 │ │Load C1→buf1 │ │Load C2→buf0 │ ...
└─────────────┘ └─────────────┘ └─────────────┘
↘ ↘ ↘
┌─────────────┐ ┌─────────────┐ ┌─────────────┐
│ Attn(C0) │ │ Attn(C1) │ │ Attn(C2) │
└─────────────┘ └─────────────┘ └─────────────┘
"""
from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs
# q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
q_batched = q.unsqueeze(1) # [batch, 1, heads, dim]
kvcache_manager = context.kvcache_manager
seq = context.chunked_seq
# Get only PREFILLED CPU blocks (exclude the current decode block)
# The decode block's KV is still in GPU decode_slot, not yet offloaded to CPU
cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
if self.layer_id == 0:
logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
if not cpu_block_table:
raise RuntimeError("Chunked decode attention failed: no prefilled CPU blocks available")
# Calculate valid tokens in the last block
# prefill_len = total prefilled tokens (current decode token not yet in CPU)
block_size = kvcache_manager.block_size
prefill_len = len(seq) - 1 # Exclude current decode token
last_block_valid_tokens = prefill_len % block_size
if last_block_valid_tokens == 0 and prefill_len > 0:
last_block_valid_tokens = block_size # Last block is full
# Apply sparse policy if enabled
if kvcache_manager.sparse_policy is not None:
policy_ctx = PolicyContext(
query_chunk_idx=0,
num_query_chunks=1,
layer_id=self.layer_id,
query=q_batched, # Decode provides query for query-aware selection
is_prefill=False,
block_size=kvcache_manager.block_size,
total_kv_len=len(cpu_block_table) * kvcache_manager.block_size,
)
cpu_block_table = kvcache_manager.sparse_policy.select_blocks(
cpu_block_table, policy_ctx
)
offload_engine = kvcache_manager.offload_engine
compute_stream = offload_engine.compute_stream
# Chunk size = capacity of each double buffer region (compute/prefetch)
# Each region uses half of decode_load_slots
chunk_size = max(1, len(offload_engine.decode_load_slots) // 2)
num_chunks = (len(cpu_block_table) + chunk_size - 1) // chunk_size
# Check if double buffering is possible (need at least 2 separate regions)
# With only 1 load slot, compute and prefetch regions overlap -> can't double buffer
can_double_buffer = len(offload_engine.decode_load_slots) >= 2
o_acc = None
lse_acc = None
# Double buffering state: True = use Compute region, False = use Prefetch region
use_compute = True
# Pre-load first chunk to Compute region (async)
first_chunk_ids = cpu_block_table[:min(chunk_size, len(cpu_block_table))]
offload_engine.load_to_compute_layer(self.layer_id, first_chunk_ids)
for chunk_idx in range(num_chunks):
start = chunk_idx * chunk_size
end = min(start + chunk_size, len(cpu_block_table))
num_blocks_in_chunk = end - start
# Wait for current buffer to be ready on compute_stream
# The load runs on transfer_stream_main, compute runs on compute_stream
compute_stream.wait_stream(offload_engine.transfer_stream_main)
# All computation on explicit compute_stream
with torch.cuda.stream(compute_stream):
# Get KV from current buffer FIRST, before prefetching overwrites it
if use_compute:
k_chunk, v_chunk = offload_engine.get_kv_for_compute(num_blocks_in_chunk)
else:
k_chunk, v_chunk = offload_engine.get_kv_for_prefetch(num_blocks_in_chunk)
# Handle partial last block: slice to only include valid tokens
# This is critical because the rest of the block contains stale data
is_last_chunk = (end == len(cpu_block_table))
if is_last_chunk and last_block_valid_tokens < block_size:
# Calculate total valid tokens in this chunk
# All blocks except the last are full, last block has last_block_valid_tokens
full_blocks = num_blocks_in_chunk - 1
valid_tokens = full_blocks * block_size + last_block_valid_tokens
# Slice KV: [batch, seqlen, heads, dim] -> [batch, valid_tokens, heads, dim]
k_chunk = k_chunk[:, :valid_tokens, :, :]
v_chunk = v_chunk[:, :valid_tokens, :, :]
# Compute attention for this chunk
o_chunk, lse_chunk = flash_attn_with_lse(
q_batched, k_chunk, v_chunk,
softmax_scale=self.scale,
causal=False,
)
# Merge with accumulated
if o_acc is None:
o_acc, lse_acc = o_chunk, lse_chunk
else:
o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)
# Trigger async prefetch/load of next chunk to the OTHER buffer
# This happens AFTER attention completes, so the data is no longer needed
if chunk_idx + 1 < num_chunks:
next_start = end
next_end = min(next_start + chunk_size, len(cpu_block_table))
next_chunk_ids = cpu_block_table[next_start:next_end]
if can_double_buffer:
if use_compute:
# Current in Compute, prefetch next to Prefetch region
offload_engine.load_to_prefetch_layer(self.layer_id, next_chunk_ids)
else:
# Current in Prefetch, prefetch next to Compute region
offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
else:
# Sync fallback: load next chunk to same slot (always compute region)
offload_engine.load_to_compute_layer(self.layer_id, next_chunk_ids)
# Swap buffers for next iteration (only matters if can_double_buffer)
use_compute = not use_compute
# Now attend to Decode region (contains accumulated decode tokens)
pos_in_block = context.decode_pos_in_block
start_pos = context.decode_start_pos_in_block
num_accumulated = pos_in_block - start_pos + 1
# IMPORTANT: Sync compute_stream with default stream before reading decode_slot
# store_kvcache writes to decode_slot on default stream (before entering this function)
# We need to ensure that write is complete before reading on compute_stream
compute_stream.wait_stream(torch.cuda.default_stream())
with torch.cuda.stream(compute_stream):
if num_accumulated > 0:
# GPU cache has no layer dimension
decode_k = offload_engine.k_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_v = offload_engine.v_cache_gpu[offload_engine.decode_slot, start_pos:pos_in_block+1]
decode_k = decode_k.unsqueeze(0)
decode_v = decode_v.unsqueeze(0)
decode_o, decode_lse = flash_attn_with_lse(
q_batched, decode_k, decode_v,
softmax_scale=self.scale,
causal=False,
)
if o_acc is None:
o_acc = decode_o
else:
o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)
if o_acc is None:
raise RuntimeError("Chunked decode attention failed: no KV available")
# Sync back to default stream before returning
# Caller expects result to be ready on default stream
torch.cuda.default_stream().wait_stream(compute_stream)
return o_acc