Files
nano-vllm/nanovllm/kvcache/sparse/policy.py
2026-01-22 22:20:34 +08:00

226 lines
6.8 KiB
Python

"""
Base class for attention policies in layerwise offload mode.
AttentionPolicy defines the interface for all attention computation,
including full attention and sparse attention methods like XAttention.
Key methods:
- estimate(): Compute sparse attention mask (optional, returns None for full attention)
- compute_prefill(): Compute prefill attention
- compute_decode(): Compute decode attention (default implementation provided)
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Tuple
import torch
# Import SparsePolicyType from config to avoid circular imports
from nanovllm.config import SparsePolicyType
@dataclass
class PolicyContext:
    """
    Per-call context handed to an attention policy.

    Bundles the information a policy may consult for sparse estimation
    (block selection) and for the attention computation itself.

    Attributes:
        query_chunk_idx: Index of the current query chunk (0-indexed).
        num_query_chunks: Total number of query chunks in this prefill.
        layer_id: Current transformer layer index.
        query: Query tensor for the current chunk, or None when it is not
            available (e.g., some prefill scenarios).
            Shape: [1, num_heads, head_dim] for decode,
            [1, seq_len, num_heads, head_dim] for prefill.
        is_prefill: True in the prefill phase, False in the decode phase.
        block_size: Number of tokens per block.
        total_kv_len: Total KV sequence length so far (for reference).
    """

    # --- Position of this call within the chunked prefill ---
    query_chunk_idx: int
    num_query_chunks: int

    # --- Which layer is being computed ---
    layer_id: int

    # --- Query data; None when the caller has no query to offer ---
    query: Optional[torch.Tensor]

    # --- Phase selector: prefill (True) vs. decode (False) ---
    is_prefill: bool

    # --- Optional sizing information (defaults preserved) ---
    block_size: int = 1024
    total_kv_len: int = 0
class AttentionPolicy(ABC):
    """
    Base class for attention policies in layerwise offload mode.

    All attention computation goes through a policy, including both
    full attention and sparse attention methods.

    The policy interface is designed for layerwise offload where:
    - The entire KV cache for a layer is on GPU during computation
    - No need for block loading from CPU during attention
    - estimate() returns a sparse mask (or None for full attention)
    - compute_prefill()/compute_decode() perform the actual attention

    Subclasses must implement compute_prefill(); every other method has a
    default implementation that may be overridden.

    Attributes:
        supports_prefill: Whether this policy can be used for prefill phase.
        supports_decode: Whether this policy can be used for decode phase.

    Example:
        class MyPolicy(AttentionPolicy):
            supports_prefill = True
            supports_decode = True

            def estimate(self, q, k, layer_id):
                # Return sparse mask or None
                return None

            def compute_prefill(self, q, k, v, layer_id, softmax_scale):
                # Compute attention
                return flash_attn_varlen_func(q, k, v, ...)
    """

    # Compatibility flags - override in subclasses
    supports_prefill: bool = True
    supports_decode: bool = True

    def initialize(
        self,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        num_cpu_blocks: int,
        dtype: torch.dtype,
        device: Optional[torch.device] = None,
    ) -> None:
        """
        Initialize policy resources.

        Called by the framework after KV cache is allocated. Override this
        to create metadata structures or pre-allocate buffers.
        Default implementation does nothing.

        Args:
            num_layers: Number of transformer layers
            num_kv_heads: Number of KV attention heads
            head_dim: Dimension per head
            num_cpu_blocks: Number of CPU blocks allocated
            dtype: Data type for tensors
            device: Device for metadata storage (GPU recommended for
                performance). May be None; the base class does not use it.
        """

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate sparse attention mask.

        For sparse policies (e.g., XAttention), computes block-level importance
        and returns a boolean mask indicating which blocks to attend.
        For full attention policy, returns None.
        This corresponds to xattn_estimate() in COMPASS.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index

        Returns:
            sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
            or None for full attention. The default implementation always
            returns None.
        """
        return None

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute prefill attention.

        The entire KV cache for this layer is on GPU. Compute attention
        between Q and K/V, optionally using sparse mask from estimate().

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """

    def compute_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute decode attention.

        KV is provided from ring buffer, containing prefill tokens + decoded tokens.
        Default implementation uses FlashAttention.

        Args:
            q: Query tensor [1, num_heads, head_dim]
            k: Key tensor [context_len+1, num_kv_heads, head_dim]
            v: Value tensor [context_len+1, num_kv_heads, head_dim]
            layer_id: Transformer layer index (unused by the default implementation)
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [1, num_heads, head_dim]
        """
        # Imported at call time rather than module level, so flash-attn is only
        # required when this default decode path actually runs.
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        # Number of key/value rows supplied by the caller (first dimension of k).
        kv_len = k.shape[0]

        # Varlen layout for one sequence: a single query token and kv_len keys.
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, kv_len], dtype=torch.int32, device=q.device)

        # causal=False: the one query token attends to every supplied key.
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=kv_len,
            softmax_scale=softmax_scale,
            causal=False,
        )

    def reset(self) -> None:
        """
        Reset policy state.

        Called when starting a new sequence or clearing state.
        Default implementation does nothing.
        """

    def __repr__(self) -> str:
        """Return the concrete class name followed by empty parentheses."""
        return f"{self.__class__.__name__}()"
# Backward compatibility alias: `SparsePolicy` is bound to the very same class
# object as `AttentionPolicy`, so existing references to `SparsePolicy`
# (imports, subclassing, isinstance checks) resolve to `AttentionPolicy`.
SparsePolicy = AttentionPolicy