import logging
import torch
from torch import nn
import triton
import triton.language as tl

from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from nanovllm.utils.context import get_context

logger = logging.getLogger(__name__)


@triton.jit
def store_kvcache_kernel(
    key_ptr,
    key_stride,
    value_ptr,
    value_stride,
    k_cache_ptr,
    v_cache_ptr,
    slot_mapping_ptr,
    D: tl.constexpr,
):
    # One program instance copies a single token's K and V vectors into its cache slot.
    idx = tl.program_id(0)
    slot = tl.load(slot_mapping_ptr + idx)
    if slot == -1:
        # A slot of -1 means this token has no cache slot assigned; skip it.
        return
    key_offsets = idx * key_stride + tl.arange(0, D)
    value_offsets = idx * value_stride + tl.arange(0, D)
    key = tl.load(key_ptr + key_offsets)
    value = tl.load(value_ptr + value_offsets)
    cache_offsets = slot * D + tl.arange(0, D)
    tl.store(k_cache_ptr + cache_offsets, key)
    tl.store(v_cache_ptr + cache_offsets, value)


def store_kvcache(key: torch.Tensor, value: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, slot_mapping: torch.Tensor):
    N, num_heads, head_dim = key.shape
    D = num_heads * head_dim
    assert key.stride(-1) == 1 and value.stride(-1) == 1
    assert key.stride(1) == head_dim and value.stride(1) == head_dim
    assert k_cache.stride(1) == D and v_cache.stride(1) == D
    assert slot_mapping.numel() == N
    store_kvcache_kernel[(N,)](key, key.stride(0), value, value.stride(0), k_cache, v_cache, slot_mapping, D)
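# A minimal shape sketch for store_kvcache, derived from the asserts above and the kernel's
# flat `slot * D` addressing (sizes are hypothetical, for illustration only; not executed here):
#
#     N, num_heads, head_dim, num_blocks, block_size = 4, 8, 128, 2, 16
#     key = torch.randn(N, num_heads, head_dim, device="cuda")
#     value = torch.randn_like(key)
#     k_cache = torch.zeros(num_blocks, block_size, num_heads, head_dim, device="cuda")
#     v_cache = torch.zeros_like(k_cache)
#     # slot i addresses row i of the cache viewed as [num_slots, D]; -1 skips that token
#     slot_mapping = torch.tensor([0, 1, 17, -1], device="cuda", dtype=torch.int64)
#     store_kvcache(key, value, k_cache, v_cache, slot_mapping)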


class Attention(nn.Module):

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        self.k_cache = self.v_cache = torch.tensor([])
        # Layer ID is set by model_runner after model creation.
        self.layer_id: int = -1

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        if k_cache.numel() and v_cache.numel():
            store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)

        if context.is_prefill:
            if context.is_chunked_prefill:
                # Chunked prefill: merge attention over KV from previous chunks.
                o = self._chunked_prefill_attention(q, k, v, context)
            else:
                if context.block_tables is not None:  # prefix cache: read K/V from the paged cache
                    k, v = k_cache, v_cache
                o = flash_attn_varlen_func(q, k, v,
                                           max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                           max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                           softmax_scale=self.scale, causal=True, block_table=context.block_tables)
        else:  # decode
            if context.is_chunked_prefill:
                # Chunked decode: all KV must be gathered from CPU and GPU.
                o = self._chunked_decode_attention(q, k, v, context)
            else:
                o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
                                            cache_seqlens=context.context_lens, block_table=context.block_tables,
                                            softmax_scale=self.scale, causal=True)
        return o

    def _chunked_prefill_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute attention for chunked prefill using the three-region GPU buffer.

        For chunked prefill:
        1. Load previous KV from CPU into the Compute/Prefetch regions (if there are previous chunks).
        2. Compute attention against the previous KV chunks (no causal mask).
        3. Compute attention against the current chunk's KV (causal).
        4. Merge all partial results with online softmax.

        The three-region design guarantees that the current chunk's KV stays in the
        Compute region while previous KV is loaded from CPU into the Prefetch region,
        so writes and loads never overlap in the same region.
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        # q, k, v shape: [total_tokens, num_heads, head_dim]
        # Reshape for flash attention: [batch, seq, heads, dim]
        q_batched = q.unsqueeze(0)  # [1, total_tokens, heads, dim]
        k_batched = k.unsqueeze(0)
        v_batched = v.unsqueeze(0)

        o_acc = None
        lse_acc = None

        # Load previous KV from CPU using the Compute/Prefetch regions.
        # Note: context.offload_engine is actually HybridKVCacheManager.
        kvcache_manager = context.offload_engine
        seq = context.chunked_seq if hasattr(context, 'chunked_seq') else None

        if kvcache_manager is not None and seq is not None and self.layer_id >= 0:
            # Get prefilled CPU blocks (blocks already written in previous chunks).
            cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

            if cpu_block_table:
                offload_engine = kvcache_manager.offload_engine
                # Use the Prefetch region to load previous KV; it cannot conflict with the
                # Compute region, which already holds the current chunk's KV.
                prefetch_size = offload_engine.num_prefetch_blocks
                num_chunks = (len(cpu_block_table) + prefetch_size - 1) // prefetch_size

                # Load the first chunk of previous KV into the Prefetch region.
                # Only layer 0 triggers the load (it loads ALL layers at once).
                first_chunk_end = min(prefetch_size, len(cpu_block_table))
                first_chunk_ids = cpu_block_table[:first_chunk_end]
                if self.layer_id == 0:
                    offload_engine.load_to_prefetch(first_chunk_ids)

                for chunk_idx in range(num_chunks):
                    start = chunk_idx * prefetch_size
                    end = min(start + prefetch_size, len(cpu_block_table))
                    num_blocks_in_chunk = end - start

                    # Double buffering is deliberately not used here: the Compute region must not
                    # be overwritten while it holds the current chunk's KV, so we wait for the
                    # Prefetch region and only load the next chunk after computing this one.

                    # Wait for the Prefetch region and get the KV.
                    offload_engine.wait_prefetch()
                    prev_k, prev_v = offload_engine.get_kv_for_prefetch(
                        self.layer_id, num_blocks_in_chunk
                    )

                    # Compute attention against this chunk (no causal mask).
                    prev_o, prev_lse = flash_attn_with_lse(
                        q_batched,
                        prev_k,
                        prev_v,
                        softmax_scale=self.scale,
                        causal=False,
                    )

                    # Merge with the accumulated output.
                    if o_acc is None:
                        o_acc, lse_acc = prev_o, prev_lse
                    else:
                        o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

                    # Load the next chunk into the Prefetch region (if any).
                    # Only layer 0 triggers the load.
                    if chunk_idx + 1 < num_chunks and self.layer_id == 0:
                        next_start = end
                        next_end = min(next_start + prefetch_size, len(cpu_block_table))
                        next_chunk_ids = cpu_block_table[next_start:next_end]
                        offload_engine.load_to_prefetch(next_chunk_ids)

        # Compute attention against the current chunk's KV (with causal mask).
        current_o, current_lse = flash_attn_with_lse(
            q_batched,
            k_batched,
            v_batched,
            softmax_scale=self.scale,
            causal=True,
        )

        # Merge with the accumulated output.
        if o_acc is None:
            final_o = current_o
        else:
            final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)

        # Remove batch dimension: [1, total_tokens, heads, dim] -> [total_tokens, heads, dim]
        return final_o.squeeze(0)

    def _chunked_decode_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute decode attention using the three-region GPU buffer.

        All KV is stored on CPU. The GPU Compute region is used as follows:
        1. Load a chunk into the Compute region.
        2. Compute attention against it.
        3. Repeat for all chunks.
        4. Finally, attend to the Decode region (slot 0), which holds the new token's KV.
        5. Merge all attention outputs with online softmax (LSE).

        Key invariant: the new token's KV lives in the Decode region (slot 0), so it is
        never overwritten by loads into the Compute region.
        """
        from nanovllm.kvcache.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        # q shape: [batch_size, num_heads, head_dim] (single decode token per sequence)
        # Need: [batch, seqlen, heads, dim]
        q_batched = q.unsqueeze(1)  # [batch, 1, heads, dim]

        # Note: context.offload_engine is actually HybridKVCacheManager.
        kvcache_manager = context.offload_engine
        seq = context.chunked_seq

        # Get all CPU blocks for this sequence.
        cpu_block_table, _ = kvcache_manager.get_all_cpu_blocks(seq)
        if self.layer_id == 0:
            logger.debug(f"Decode attention: cpu_block_table={cpu_block_table}, seq.block_table={list(seq.block_table)}")
        if not cpu_block_table:
            raise RuntimeError("Chunked decode attention failed: no CPU blocks available")

        # Get the actual offload_engine for three-region operations.
        offload_engine = kvcache_manager.offload_engine

        # Calculate chunk info based on the Compute region size.
        compute_size = offload_engine.num_compute_blocks
        num_chunks = (len(cpu_block_table) + compute_size - 1) // compute_size

        o_acc = None
        lse_acc = None

        for chunk_idx in range(num_chunks):
            start = chunk_idx * compute_size
            end = min(start + compute_size, len(cpu_block_table))
            num_blocks_in_chunk = end - start
            chunk_ids = cpu_block_table[start:end]

            # Load this chunk into the Compute region.
            # Only layer 0 triggers the load (it loads ALL layers at once).
            if self.layer_id == 0:
                offload_engine.load_to_compute(chunk_ids)

            # Wait for the Compute region to be ready and get the KV.
            offload_engine.wait_compute()
            k_chunk, v_chunk = offload_engine.get_kv_for_compute(
                self.layer_id, num_blocks_in_chunk
            )

            # Compute attention for this chunk.
            o_chunk, lse_chunk = flash_attn_with_lse(
                q_batched, k_chunk, v_chunk,
                softmax_scale=self.scale,
                causal=False,
            )

            # Merge with the accumulated output.
            if o_acc is None:
                o_acc, lse_acc = o_chunk, lse_chunk
            else:
                o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, o_chunk, lse_chunk)

        # Now attend to the Decode region, which holds the new token's KV.
        # This is the token being decoded: a single token at position pos_in_block.
        pos_in_block = context.decode_pos_in_block
        decode_k, decode_v = offload_engine.get_kv_for_decode_slot(self.layer_id, pos_in_block)
        decode_o, decode_lse = flash_attn_with_lse(
            q_batched, decode_k, decode_v,
            softmax_scale=self.scale,
            causal=False,
        )

        # Merge with the accumulated output.
        if o_acc is None:
            o_acc = decode_o
        else:
            o_acc, _ = merge_attention_outputs(o_acc, lse_acc, decode_o, decode_lse)

        if o_acc is None:
            raise RuntimeError("Chunked decode attention failed: no KV available")

        # Output shape: [batch, 1, heads, dim] (same as normal decode)
        return o_acc