- Add compute_prefill() and compute_decode() GPU-only methods to SparsePolicy base class
- Implement GPU-only methods in FullAttentionPolicy using flash_attn
- Add sparse_policy parameter to GPUOnlyManager
- Update create_kvcache_manager() to create FullAttentionPolicy for GPU-only mode
- Route GPU-only attention through sparse_policy in attention.py
- Pass kvcache_manager to context for policy access
- Add --enable-policy flag to bench.py for testing
- Handle warmup phase when kvcache_manager is not yet allocated

This allows GPU-only mode to use the same policy architecture as CPU offload mode, enabling future sparse attention implementations (Quest, XAttention) in GPU-only mode.

Performance verified: ~4890 tok/s (unchanged from baseline)

Generated with [Claude Code](https://claude.ai/code) via [Happy](https://happy.engineering)

Co-Authored-By: Claude <noreply@anthropic.com>
Co-Authored-By: Happy <yesreply@happy.engineering>
340 lines
14 KiB
Python
340 lines
14 KiB
Python
import logging
|
||
import torch
|
||
import torch.cuda.nvtx
|
||
from torch import nn
|
||
|
||
from flash_attn.flash_attn_interface import flash_attn_varlen_func, flash_attn_with_kvcache
|
||
from nanovllm.utils.context import get_context
|
||
from nanovllm.kvcache.sparse.policy import PolicyContext
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def store_kvcache(
    key: torch.Tensor,
    value: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
):
    """
    Store key/value tensors into the KV cache using a flat slot mapping.

    This is a pure PyTorch implementation replacing the previous Triton kernel.
    It scatters rows in place with ``index_copy_`` into a flattened view of the cache.

    Args:
        key: [N, num_kv_heads, head_dim]
        value: [N, num_kv_heads, head_dim]
        k_cache: any tensor viewable as [total_slots, num_kv_heads * head_dim],
            e.g. [num_blocks, block_size, num_kv_heads, head_dim].
        v_cache: same shape as k_cache.
        slot_mapping: [N] flat slot indices. Outside CUDA graph capture, any
            negative entry (e.g. -1) means "skip this token".
    """
    # Only CUDA tensors can take part in a graph capture. Checking ``is_cuda``
    # first also keeps this usable on CPU tensors and on CPU-only PyTorch builds,
    # where the capture query is not available.
    is_capturing = key.is_cuda and torch.cuda.is_current_stream_capturing()

    if is_capturing:
        # During CUDA graph capture, assume all slots are valid: boolean indexing
        # yields data-dependent shapes, which a graph cannot record.
        # NOTE(review): the replayed graph does no filtering, so a negative slot
        # written into slot_mapping before replay (e.g. batch padding) reaches
        # index_copy_ as an out-of-range index. Confirm callers never pad with -1.
        valid_slots, valid_keys, valid_values = slot_mapping, key, value
    else:
        # Eager execution: drop tokens whose slot is negative.
        valid_mask = slot_mapping >= 0
        if not valid_mask.any():
            return
        valid_slots = slot_mapping[valid_mask]
        valid_keys = key[valid_mask]  # [M, num_kv_heads, head_dim]
        valid_values = value[valid_mask]

    # View each cache as [total_slots, D] where D = num_kv_heads * head_dim.
    _, num_kv_heads, head_dim = key.shape
    row_width = num_kv_heads * head_dim
    total_slots = k_cache.numel() // row_width

    k_cache_flat = k_cache.view(total_slots, row_width)
    v_cache_flat = v_cache.view(total_slots, row_width)

    # index_copy_ needs an int64 index; convert once and reuse it for both caches.
    # With an empty index, index_copy_ leaves the cache unchanged.
    index = valid_slots.long()
    k_cache_flat.index_copy_(0, index, valid_keys.reshape(-1, row_width))
    v_cache_flat.index_copy_(0, index, valid_values.reshape(-1, row_width))
|
||
|
||
|
||
class Attention(nn.Module):
    """
    Attention layer: stores this step's K/V and dispatches attention according to
    the mode recorded in the global context (see ``forward``).

    Three paths exist:
      * warmup (``context.kvcache_manager is None``): flash_attn is called directly;
      * chunked CPU offload (``context.is_chunked_prefill`` and the kvcache manager
        has an ``offload_engine``): K/V go through the offload engine's per-layer
        buffers, and attention is delegated to the sparse policy's
        ``compute_chunked_prefill`` / ``compute_chunked_decode``;
      * GPU-only: K/V are written into ``self.k_cache`` / ``self.v_cache``, and
        attention is delegated to ``sparse_policy.compute_prefill`` /
        ``compute_decode``.
    """

    def __init__(
        self,
        num_heads,
        head_dim,
        scale,
        num_kv_heads,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = scale
        self.num_kv_heads = num_kv_heads
        # Empty placeholders. The real cache tensors are presumably assigned
        # elsewhere once the KV cache is allocated (not visible in this file).
        self.k_cache = self.v_cache = torch.tensor([])
        # Layer ID set by model_runner after model creation
        self.layer_id: int = -1

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        Store the new K/V for this step and compute attention.

        Args:
            q: query tensor for this step's tokens.
            k, v: keys/values for the same tokens, [num_tokens, kv_heads, head_dim].

        Returns:
            The attention output from flash_attn (warmup) or from the sparse policy.

        Raises:
            RuntimeError: if the kvcache manager has no sparse policy, or if the
                context requests chunked (offload) mode but the manager has no
                offload engine.
        """
        context = get_context()
        k_cache, v_cache = self.k_cache, self.v_cache
        # Read the manager once, with a single lookup style for every check below.
        kvcache_manager = getattr(context, 'kvcache_manager', None)

        # Chunked CPU-offload mode needs a manager that owns an offload engine.
        is_chunked_offload = (
            context.is_chunked_prefill
            and kvcache_manager is not None
            and hasattr(kvcache_manager, 'offload_engine')
        )

        if is_chunked_offload and context.is_prefill:
            # Chunked prefill: write KV to the per-layer prefill buffer (not a GPU
            # slot). Each layer has its own buffer, so offloads can run async.
            offload_engine = kvcache_manager.offload_engine
            compute_stream = offload_engine.compute_stream
            chunk_idx = getattr(context, 'current_chunk_idx', -1)

            # Wait for the default stream so the slot_mapping transfer is complete.
            compute_stream.wait_stream(torch.cuda.default_stream())
            # NOTE(review): k/v are produced on the default stream and consumed on
            # compute_stream. Unless a later step makes the default stream wait on
            # compute_stream, consider k.record_stream(compute_stream) (and for v)
            # so the caching allocator cannot reuse their memory too early.
            with torch.cuda.stream(compute_stream):
                # k, v shape: [num_tokens, kv_heads, head_dim]; GPU-to-GPU copy.
                offload_engine.write_to_prefill_buffer(self.layer_id, k, v, chunk_idx=chunk_idx)
        elif not is_chunked_offload:
            # Normal mode: store on the current stream. Skipped while the caches
            # are still the empty placeholders set in __init__.
            if k_cache.numel() and v_cache.numel():
                store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
        # Chunked decode stores nothing here. Its KV goes to the per-layer decode
        # buffer in the decode branch below.

        if kvcache_manager is None:
            # Warmup: the kvcache manager (and its sparse policy) is not allocated
            # yet, so call flash_attn directly.
            if context.is_prefill:
                return flash_attn_varlen_func(
                    q, k, v,
                    max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                    max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                    softmax_scale=self.scale, causal=True,
                )
            return flash_attn_with_kvcache(
                q.unsqueeze(1), k_cache, v_cache,
                cache_seqlens=context.context_lens, block_table=context.block_tables,
                softmax_scale=self.scale, causal=True,
            )

        sparse_policy = kvcache_manager.sparse_policy
        if sparse_policy is None:
            # Raise rather than assert so the check survives `python -O`.
            raise RuntimeError("sparse_policy must not be None")
        if context.is_chunked_prefill and not is_chunked_offload:
            # The chunked helpers below need kvcache_manager.offload_engine.
            raise RuntimeError(
                "chunked prefill/decode requires a kvcache_manager with an offload_engine"
            )

        if context.is_prefill:
            if context.is_chunked_prefill:
                # Chunked prefill: merge attention with previous KV (CPU offload mode).
                return self._chunked_prefill_attention(q, k, v, context)
            # GPU-only mode: attend over the paged cache when block_tables is
            # provided, otherwise over this step's k, v directly.
            if context.block_tables is not None:
                k_for_attn, v_for_attn = k_cache, v_cache
            else:
                k_for_attn, v_for_attn = k, v
            return sparse_policy.compute_prefill(
                q, k_for_attn, v_for_attn,
                context.cu_seqlens_q, context.cu_seqlens_k,
                context.max_seqlen_q, context.max_seqlen_k,
                self.scale, self.layer_id,
                context.block_tables,
            )

        # Decode.
        if context.is_chunked_prefill:
            # Chunked decode (CPU offload mode): store the current token in the
            # per-layer decode buffer. The GPU cache has no layer dimension, so
            # layers would otherwise overwrite each other in decode_slot.
            # NOTE(review): squeeze(0) assumes k, v are [1, kv_heads, head_dim],
            # i.e. one decode token. Confirm batch size 1 in offload mode.
            kvcache_manager.offload_engine.write_to_decode_buffer(
                self.layer_id, context.decode_pos_in_block, k.squeeze(0), v.squeeze(0)
            )
            return self._chunked_decode_attention(q, k, v, context)

        # GPU-only mode: use the policy for attention over the paged cache.
        return sparse_policy.compute_decode(
            q, k_cache, v_cache,
            context.context_lens, self.scale, self.layer_id,
            context.block_tables,
        )

    def _select_cpu_blocks(
        self,
        sparse_policy,
        kvcache_manager,
        offload_engine,
        cpu_block_table,
        q: torch.Tensor,
        query_chunk_idx: int,
        num_query_chunks: int,
        is_prefill: bool,
    ):
        """
        Ask ``sparse_policy.select_blocks`` which blocks of ``cpu_block_table`` to
        attend to. Returns ``[]`` when ``cpu_block_table`` is empty.
        """
        if not cpu_block_table:
            return []
        block_size = kvcache_manager.block_size
        policy_ctx = PolicyContext(
            query_chunk_idx=query_chunk_idx,
            num_query_chunks=num_query_chunks,
            layer_id=self.layer_id,
            query=q,  # Pass query for sparse policies that need it
            is_prefill=is_prefill,
            block_size=block_size,
            total_kv_len=len(cpu_block_table) * block_size,
        )
        selected_blocks = sparse_policy.select_blocks(cpu_block_table, offload_engine, policy_ctx)
        # Lazy %-formatting: this runs per layer per step, so skip formatting
        # unless DEBUG is enabled. The prefix keeps the original two messages.
        logger.debug(
            "[DEBUG] %sselect_blocks: %d -> %d blocks",
            "" if is_prefill else "decode ",
            len(cpu_block_table), len(selected_blocks),
        )
        return selected_blocks

    def _chunked_prefill_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute attention with per-layer prefill buffer for async offload.

        Simplified design:
        - All computation logic is delegated to sparse_policy.compute_chunked_prefill()
        - This method only handles block selection and the async offload afterwards

        The policy handles:
        1. Loading historical blocks from CPU
        2. Computing attention against historical KV (no causal mask)
        3. Computing attention against current KV from prefill buffer (causal)
        4. Merging all results

        Raises:
            RuntimeError: if the kvcache manager has no sparse policy.
        """
        current_chunk_idx = context.current_chunk_idx
        torch.cuda.nvtx.range_push(f"ChunkedPrefill: L{self.layer_id} Chunk{current_chunk_idx}")
        try:
            num_tokens = k.shape[0]
            kvcache_manager = context.kvcache_manager
            seq = getattr(context, 'chunked_seq', None)
            offload_engine = kvcache_manager.offload_engine

            # Get sparse policy - required for chunked prefill
            sparse_policy = kvcache_manager.sparse_policy
            if sparse_policy is None:
                raise RuntimeError("sparse_policy is required for chunked prefill")

            # Step 1: Get historical CPU blocks
            cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

            # Step 2: Filter them with the policy before compute_chunked_prefill
            selected_blocks = self._select_cpu_blocks(
                sparse_policy, kvcache_manager, offload_engine, cpu_block_table, q,
                query_chunk_idx=current_chunk_idx,
                num_query_chunks=current_chunk_idx + 1,
                is_prefill=True,
            )

            logger.debug(
                "[DEBUG] Calling sparse_policy.compute_chunked_prefill, "
                "policy=%s, layer=%s, chunk=%s",
                sparse_policy, self.layer_id, current_chunk_idx,
            )

            # Delegate computation to policy with pre-selected blocks
            final_o = sparse_policy.compute_chunked_prefill(
                q, k, v,
                self.layer_id,
                self.scale,
                offload_engine,
                kvcache_manager,
                current_chunk_idx,
                seq,
                num_tokens,
                selected_blocks,
            )
        finally:
            # Always close the NVTX range, even if the policy raises.
            torch.cuda.nvtx.range_pop()  # ChunkedPrefill

        # Per-layer async offload of the prefill buffer to this chunk's CPU block.
        # Each layer has its own buffer, so no waiting is done here.
        if offload_engine is not None and seq is not None:
            cpu_block_ids, _ = kvcache_manager.get_all_cpu_blocks(seq)
            if current_chunk_idx < len(cpu_block_ids):
                offload_engine.offload_prefill_buffer_async(
                    self.layer_id, cpu_block_ids[current_chunk_idx], num_tokens
                )

        return final_o

    def _chunked_decode_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> torch.Tensor:
        """
        Compute decode attention by delegating to sparse policy.

        Simplified design:
        - All computation logic is delegated to sparse_policy.compute_chunked_decode()
        - This method only validates the policy, selects blocks, and delegates

        The policy handles:
        1. Loading prefilled blocks from CPU via pipeline
        2. Computing attention against prefilled KV
        3. Reading accumulated decode tokens from decode buffer
        4. Merging all results

        Raises:
            RuntimeError: if the kvcache manager has no sparse policy.
        """
        kvcache_manager = context.kvcache_manager
        seq = context.chunked_seq
        offload_engine = kvcache_manager.offload_engine

        # Get sparse policy - required for chunked decode
        sparse_policy = kvcache_manager.sparse_policy
        if sparse_policy is None:
            raise RuntimeError("sparse_policy is required for chunked decode")

        # If the policy does not support decode, fall back to FullAttentionPolicy
        # (e.g., XAttentionBSAPolicy only supports prefill).
        if not sparse_policy.supports_decode:
            # Local import kept from the original (possibly to avoid an import cycle).
            from nanovllm.kvcache.sparse import FullAttentionPolicy
            sparse_policy = FullAttentionPolicy()
            logger.debug(
                "[DEBUG] %s doesn't support decode, falling back to FullAttentionPolicy",
                kvcache_manager.sparse_policy,
            )

        # Step 1: Get prefilled CPU blocks
        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

        # Step 2: Filter them with the policy before compute_chunked_decode
        selected_blocks = self._select_cpu_blocks(
            sparse_policy, kvcache_manager, offload_engine, cpu_block_table, q,
            query_chunk_idx=0,
            num_query_chunks=1,
            is_prefill=False,
        )

        logger.debug(
            "[DEBUG] Calling sparse_policy.compute_chunked_decode, policy=%s, layer=%s",
            sparse_policy, self.layer_id,
        )

        # Delegate computation to policy with pre-selected blocks
        return sparse_policy.compute_chunked_decode(
            q,
            self.layer_id,
            self.scale,
            offload_engine,
            kvcache_manager,
            seq,
            selected_blocks,
        )
|