"""
Base class for attention policies in layerwise offload mode.

AttentionPolicy defines the interface for all attention computation,
including full attention and sparse attention methods like XAttention.

Key methods:
- estimate(): Compute sparse attention mask (optional, returns None for full attention)
- compute_prefill(): Compute prefill attention
- compute_decode(): Compute decode attention (default implementation provided)
"""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass
|
|
from typing import List, Optional, Tuple
|
|
import torch
|
|
|
|
# Import SparsePolicyType from config to avoid circular imports
|
|
from nanovllm.config import SparsePolicyType
|
|
|
|
|
|
@dataclass
class PolicyContext:
    """Per-call context handed to an attention policy.

    Bundles every piece of information a policy needs for sparse
    estimation and for computing attention on the current chunk.
    """

    query_chunk_idx: int    # 0-indexed position of the current query chunk
    num_query_chunks: int   # total number of query chunks in this prefill
    layer_id: int           # index of the current transformer layer
    # Query tensor for the current chunk, when available:
    #   decode:  [1, num_heads, head_dim]
    #   prefill: [1, seq_len, num_heads, head_dim]
    # May be None if not available (e.g., some prefill scenarios).
    query: Optional[torch.Tensor]
    is_prefill: bool        # True during the prefill phase, False during decode
    block_size: int = 1024  # number of tokens per block
    total_kv_len: int = 0   # total KV sequence length so far (for reference)
class AttentionPolicy(ABC):
    """
    Base class for attention policies in layerwise offload mode.

    All attention computation goes through a policy, including both
    full attention and sparse attention methods.

    The policy interface is designed for layerwise offload where:
    - The entire KV cache for a layer is on GPU during computation
    - No need for block loading from CPU during attention
    - estimate() returns a sparse mask (or None for full attention)
    - compute_prefill()/compute_decode() perform the actual attention

    Attributes:
        supports_prefill: Whether this policy can be used for prefill phase.
        supports_decode: Whether this policy can be used for decode phase.

    Example:
        class MyPolicy(AttentionPolicy):
            supports_prefill = True
            supports_decode = True

            def estimate(self, q, k, layer_id):
                # Return sparse mask or None
                return None

            def compute_prefill(self, q, k, v, layer_id, softmax_scale):
                # Compute attention
                return flash_attn_varlen_func(q, k, v, ...)
    """

    # Compatibility flags - override in subclasses
    supports_prefill: bool = True
    supports_decode: bool = True

    def initialize(
        self,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        num_cpu_blocks: int,
        dtype: torch.dtype,
        # Fix: annotation was `torch.device = None`, inconsistent with the
        # None default; Optional makes the accepted values explicit.
        device: Optional[torch.device] = None,
    ) -> None:
        """
        Initialize policy resources.

        Called by the framework after KV cache is allocated. Override this
        to create metadata structures or pre-allocate buffers.
        Default implementation does nothing.

        Args:
            num_layers: Number of transformer layers
            num_kv_heads: Number of KV attention heads
            head_dim: Dimension per head
            num_cpu_blocks: Number of CPU blocks allocated
            dtype: Data type for tensors
            device: Device for metadata storage (GPU recommended for performance)
        """
        pass

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate sparse attention mask.

        For sparse policies (e.g., XAttention), computes block-level importance
        and returns a boolean mask indicating which blocks to attend.
        For full attention policy, returns None.

        This corresponds to xattn_estimate() in COMPASS.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index

        Returns:
            sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
                or None for full attention
        """
        # Default: no sparsity -> full attention.
        return None

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute prefill attention.

        The entire KV cache for this layer is on GPU. Compute attention
        between Q and K/V, optionally using sparse mask from estimate().

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        pass

    def compute_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute decode attention.

        KV is provided from ring buffer, containing prefill tokens + decoded tokens.
        Default implementation uses FlashAttention.

        Args:
            q: Query tensor [1, num_heads, head_dim]
            k: Key tensor [context_len+1, num_kv_heads, head_dim]
            v: Value tensor [context_len+1, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [1, num_heads, head_dim]
        """
        # Imported lazily so the module can be loaded without flash-attn installed
        # (e.g., when a subclass overrides this method with another backend).
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        context_len = k.shape[0]
        # Cumulative sequence lengths for the varlen kernel: a single batch
        # with one query token against the full context.
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=context_len,
            softmax_scale=softmax_scale,
            # The single decode query is the newest token and may attend to the
            # entire provided context, so no causal masking is required.
            causal=False,
        )

    def reset(self) -> None:
        """
        Reset policy state.

        Called when starting a new sequence or clearing state.
        Default implementation does nothing.
        """
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
# Backward-compatibility alias: the base class was previously named
# SparsePolicy; keep the old name importable for existing callers.
SparsePolicy = AttentionPolicy