# nano-vllm/nanovllm/layers/attention.py
import torch
from torch import nn
import triton
import triton.language as tl
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context


@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    # One program per token: copy its K/V vector into the cache slot assigned by slot_mapping.
    idx = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + idx)
    if slot == -1:  # token has no slot assigned (e.g. already-cached prefix); skip
        return
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)


def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
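

# --- Illustrative reference (sketch, not used at runtime) --------------------
# A pure-PyTorch equivalent of the Triton scatter above, assuming k_cache and
# v_cache are contiguous and can be viewed as [num_slots, num_kv_heads * head_dim].
# It only documents what store_kvcache_kernel does; the helper name and the
# contiguity assumption are ours, not part of nano-vllm.
def _store_kvcache_reference(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    valid = slot_mapping != -1                       # -1 means "no slot assigned"
    slots = slot_mapping[valid].long()
    k_cache.view(-1, D)[slots] = key.reshape(N, D)[valid]
    v_cache.view(-1, D)[slots] = value.reshape(N, D)[valid]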


class Attention(nn.Module):

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])
        # Layer ID set by model_runner after model creation
        self.layer_id: int = -1
    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        if context.is_prefill:
            if context.is_chunked_prefill:
                # Chunked prefill: merge attention from previous KV
                o = self._chunked_prefill_attention(q, k, v, context)
            else:
                if context.block_tables is not None:    # prefix cache
                    k, v = k_cache, v_cache
                o = flash_attn_varlen_func(q, k, v,
                                           max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                           max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                           softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:    # decode
            if context.is_chunked_prefill:
                # Chunked decode: need to load all KV from CPU+GPU
                o = self._chunked_decode_attention(q, k, v, context)
            else:
                o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                            cache_seqlens=context.context_lens, block_table=context.block_tables,
                                            softmax_scale=self.scale, causal=True)
        return o
    def _chunked_prefill_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute attention with the Ping-Pong dual buffer for chunked prefill.

        For chunked prefill:
        1. Load previous KV from CPU using Ping-Pong (if any previous chunks exist)
        2. Compute attention against previous KV chunks (no causal mask)
        3. Compute attention against the current chunk's KV (causal)
        4. Merge all results using online softmax
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        # q, k, v shape: [total_tokens, num_heads, head_dim]
        # Reshape for flash attention: [batch, seq, heads, dim]
        q_batched = q.unsqueeze(0)  # [1, total_tokens, heads, dim]
        k_batched = k.unsqueeze(0)
        v_batched = v.unsqueeze(0)

        o_acc = None
        lse_acc = None

        # Load previous KV from CPU using Ping-Pong
        # Note: context.offload_engine is actually HybridKVCacheManager
        kvcache_manager = context.offload_engine
        seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None
        if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
            # Get prefilled CPU blocks (blocks already written in previous chunks)
            cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)
            if cpu_block_table:
                offload_engine = kvcache_manager.offload_engine
                ping_size = offload_engine.ping_size
                num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size
                current_buffer = "ping"

                # Prefetch the first chunk into the Ping buffer
                first_chunk_end = min(ping_size, len(cpu_block_table))
                first_chunk_ids = cpu_block_table[:first_chunk_end]
                offload_engine.load_to_ping(first_chunk_ids)

                for chunk_idx in range(num_chunks):
                    start = chunk_idx * ping_size
                    end = min(start + ping_size, len(cpu_block_table))
                    num_blocks_in_chunk = end - start

                    # Prefetch the next chunk into the OTHER buffer
                    if chunk_idx + 1 < num_chunks:
                        next_start = end
                        next_end = min(next_start + ping_size, len(cpu_block_table))
                        next_chunk_ids = cpu_block_table[next_start:next_end]
                        if current_buffer == "ping":
                            offload_engine.load_to_pong(next_chunk_ids)
                        else:
                            offload_engine.load_to_ping(next_chunk_ids)

                    # Wait for the current buffer and get its KV
                    if current_buffer == "ping":
                        offload_engine.wait_ping()
                        prev_k, prev_v = offload_engine.get_kv_for_ping_slots(
                            self.layer_id, num_blocks_in_chunk
                        )
                    else:
                        offload_engine.wait_pong()
                        prev_k, prev_v = offload_engine.get_kv_for_pong_slots(
                            self.layer_id, num_blocks_in_chunk
                        )

                    # Compute attention against this chunk (no causal mask)
                    prev_o, prev_lse = flash_attn_with_lse(
                        q_batched,
                        prev_k,
                        prev_v,
                        softmax_scale=self.scale,
                        causal=False,
                    )

                    # Merge with the accumulated result
                    if o_acc is None:
                        o_acc, lse_acc = prev_o, prev_lse
                    else:
                        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

                    # Switch buffers
                    current_buffer = "pong" if current_buffer == "ping" else "ping"

        # Compute attention against the current chunk's KV (with causal mask)
        current_o, current_lse = flash_attn_with_lse(
            q_batched,
            k_batched,
            v_batched,
            softmax_scale=self.scale,
            causal=True,
        )

        # Merge with the accumulated result
        if o_acc is None:
            final_o = current_o
        else:
            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)

        # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
        return final_o.squeeze(0)
    def _chunked_decode_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute decode attention with the Ping-Pong dual buffer.

        All KV is stored on CPU. Uses Ping-Pong buffers on GPU:
        1. Load the first chunk into the Ping buffer
        2. While computing on the current buffer, prefetch the next chunk into the other buffer
        3. Alternate between the Ping and Pong buffers
        4. Merge attention outputs using online softmax (LSE)
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
        # Need: [batch, seqlen, heads, dim]
        q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]

        # Note: context.offload_engine is actually HybridKVCacheManager
        kvcache_manager = context.offload_engine
        seq = context.chunked_seq

        # Get all CPU blocks for this sequence
        cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
        if not cpu_block_table:
            raise RuntimeError("Chunked decode attention failed: no CPU blocks available")

        # Get the actual offload_engine for Ping-Pong operations
        offload_engine = kvcache_manager.offload_engine

        # Calculate chunk info
        ping_size = offload_engine.ping_size
        num_chunks = (len(cpu_block_table) + ping_size - 1) // ping_size

        o_acc = None
        lse_acc = None
        current_buffer = "ping"

        # Prefetch the first chunk into the Ping buffer (loads all layers at once)
        first_chunk_end = min(ping_size, len(cpu_block_table))
        first_chunk_ids = cpu_block_table[:first_chunk_end]
        offload_engine.load_to_ping(first_chunk_ids)

        for chunk_idx in range(num_chunks):
            start = chunk_idx * ping_size
            end = min(start + ping_size, len(cpu_block_table))
            num_blocks_in_chunk = end - start

            # Prefetch the next chunk into the OTHER buffer (overlapped with current computation)
            if chunk_idx + 1 < num_chunks:
                next_start = end
                next_end = min(next_start + ping_size, len(cpu_block_table))
                next_chunk_ids = cpu_block_table[next_start:next_end]
                if current_buffer == "ping":
                    offload_engine.load_to_pong(next_chunk_ids)
                else:
                    offload_engine.load_to_ping(next_chunk_ids)

            # Wait for the current buffer to be ready and get its KV
            if current_buffer == "ping":
                offload_engine.wait_ping()
                k_chunk, v_chunk = offload_engine.get_kv_for_ping_slots(
                    self.layer_id, num_blocks_in_chunk
                )
            else:
                offload_engine.wait_pong()
                k_chunk, v_chunk = offload_engine.get_kv_for_pong_slots(
                    self.layer_id, num_blocks_in_chunk
                )

            # Compute attention for this chunk
            o_chunk, lse_chunk = flash_attn_with_lse(
                q_batched, k_chunk, v_chunk,
                softmax_scale=self.scale,
                causal=False,
            )

            # Merge with the accumulated result
            if o_acc is None:
                o_acc, lse_acc = o_chunk, lse_chunk
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)

            # Switch buffers for the next iteration
            current_buffer = "pong" if current_buffer == "ping" else "ping"

        if o_acc is None:
            raise RuntimeError("Chunked decode attention failed: no KV available")

        # Output shape: [batch, 1, heads, dim] (same as normal decode)
        return o_acc
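

# --- Illustrative sketch of the online-softmax merge (assumed semantics) -----
# merge_attention_outputs() comes from nanovllm.kvcache.chunked_attention; the
# reference below shows how two partial attention results over disjoint key
# sets are typically combined from their log-sum-exp (LSE) statistics. The
# function name and the shape convention (o: [batch, seq, heads, dim],
# lse: [batch, heads, seq], as returned by flash-attn) are assumptions made for
# illustration, not the library's actual implementation.
def _merge_attention_outputs_reference(o1: torch.Tensor, lse1: torch.Tensor,
                                       o2: torch.Tensor, lse2: torch.Tensor):
    lse = torch.logaddexp(lse1, lse2)                          # combined LSE
    w1 = torch.exp(lse1 - lse).transpose(1, 2).unsqueeze(-1)   # [batch, seq, heads, 1]
    w2 = torch.exp(lse2 - lse).transpose(1, 2).unsqueeze(-1)
    return w1 * o1 + w2 * o2, lse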