[WIP] need refactor.
This commit is contained in:
@@ -1,13 +1,18 @@
|
||||
"""
|
||||
Base class for sparse attention policies.
|
||||
Base class for attention policies in layerwise offload mode.
|
||||
|
||||
Sparse attention policies determine which KV cache blocks to load
|
||||
from CPU for each query chunk during chunked attention computation.
|
||||
AttentionPolicy defines the interface for all attention computation,
|
||||
including full attention and sparse attention methods like XAttention.
|
||||
|
||||
Key methods:
|
||||
- estimate(): Compute sparse attention mask (optional, returns None for full attention)
|
||||
- compute_prefill(): Compute prefill attention
|
||||
- compute_decode(): Compute decode attention (default implementation provided)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Any
|
||||
from typing import List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
# Import SparsePolicyType from config to avoid circular imports
|
||||
@@ -17,10 +22,10 @@ from nanovllm.config import SparsePolicyType
|
||||
@dataclass
|
||||
class PolicyContext:
|
||||
"""
|
||||
Context passed to sparse policy for block selection.
|
||||
Context passed to attention policy for block selection.
|
||||
|
||||
This dataclass contains all information needed by a sparse policy
|
||||
to decide which blocks to load for the current query chunk.
|
||||
This dataclass contains all information needed by an attention policy
|
||||
for sparse estimation and attention computation.
|
||||
"""
|
||||
|
||||
query_chunk_idx: int
|
||||
@@ -49,40 +54,41 @@ class PolicyContext:
|
||||
"""Total KV sequence length so far (for reference)."""
|
||||
|
||||
|
||||
class SparsePolicy(ABC):
|
||||
class AttentionPolicy(ABC):
|
||||
"""
|
||||
Abstract base class for sparse attention policies.
|
||||
Base class for attention policies in layerwise offload mode.
|
||||
|
||||
Subclass this and implement select_blocks() to create custom
|
||||
sparse attention patterns. The policy receives context about
|
||||
the current query chunk and returns which KV blocks to load.
|
||||
All attention computation goes through a policy, including both
|
||||
full attention and sparse attention methods.
|
||||
|
||||
The policy interface is designed for layerwise offload where:
|
||||
- The entire KV cache for a layer is on GPU during computation
|
||||
- No need for block loading from CPU during attention
|
||||
- estimate() returns a sparse mask (or None for full attention)
|
||||
- compute_prefill()/compute_decode() perform the actual attention
|
||||
|
||||
Attributes:
|
||||
supports_prefill: Whether this policy can be used for prefill phase.
|
||||
supports_decode: Whether this policy can be used for decode phase.
|
||||
|
||||
Example:
|
||||
class MySparsePolicy(SparsePolicy):
|
||||
supports_prefill = False # decode-only policy
|
||||
class MyPolicy(AttentionPolicy):
|
||||
supports_prefill = True
|
||||
supports_decode = True
|
||||
|
||||
def select_blocks(self, available_blocks, ctx):
|
||||
# Load first block and last 2 blocks
|
||||
if len(available_blocks) <= 3:
|
||||
return available_blocks
|
||||
return [available_blocks[0]] + available_blocks[-2:]
|
||||
def estimate(self, q, k, layer_id):
|
||||
# Return sparse mask or None
|
||||
return None
|
||||
|
||||
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
|
||||
# Compute attention
|
||||
return flash_attn_varlen_func(q, k, v, ...)
|
||||
"""
|
||||
|
||||
# Compatibility flags - override in subclasses
|
||||
supports_prefill: bool = True
|
||||
supports_decode: bool = True
|
||||
|
||||
# Whether this policy requires selective block loading during decode
|
||||
# If True: OffloadEngine will call select_blocks() before loading KV from CPU
|
||||
# If False: OffloadEngine will load all blocks (select_blocks ignored for load)
|
||||
# Example: MInference=False (only affects attention), Quest=True (affects load)
|
||||
requires_block_selection: bool = False
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
num_layers: int,
|
||||
@@ -96,7 +102,7 @@ class SparsePolicy(ABC):
|
||||
Initialize policy resources.
|
||||
|
||||
Called by the framework after KV cache is allocated. Override this
|
||||
to create metadata structures (e.g., BlockMetadataManager for Quest).
|
||||
to create metadata structures or pre-allocate buffers.
|
||||
Default implementation does nothing.
|
||||
|
||||
Args:
|
||||
@@ -109,76 +115,98 @@ class SparsePolicy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def select_blocks(
|
||||
def estimate(
|
||||
self,
|
||||
available_blocks: List[int],
|
||||
ctx: PolicyContext,
|
||||
) -> List[int]:
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
layer_id: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Select which KV blocks to load for the current query chunk.
|
||||
Estimate sparse attention mask.
|
||||
|
||||
This is the core method that defines the sparse attention pattern.
|
||||
The returned blocks will be loaded from CPU to GPU for attention
|
||||
computation against the current query chunk.
|
||||
For sparse policies (e.g., XAttention), computes block-level importance
|
||||
and returns a boolean mask indicating which blocks to attend.
|
||||
For full attention policy, returns None.
|
||||
|
||||
This corresponds to xattn_estimate() in COMPASS.
|
||||
|
||||
Args:
|
||||
available_blocks: List of CPU block IDs that contain KV cache
|
||||
from previous chunks. These are ordered by
|
||||
their position in the sequence.
|
||||
ctx: PolicyContext with information about the current query
|
||||
chunk, layer, phase (prefill/decode), etc.
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||
layer_id: Transformer layer index
|
||||
|
||||
Returns:
|
||||
List of block IDs to load (must be a subset of available_blocks).
|
||||
The order may affect performance (sequential access is faster).
|
||||
Returning [] means no previous blocks will be loaded.
|
||||
sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
|
||||
or None for full attention
|
||||
"""
|
||||
pass
|
||||
return None
|
||||
|
||||
def on_prefill_offload(
|
||||
@abstractmethod
|
||||
def compute_prefill(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer_id: int,
|
||||
k_cache: torch.Tensor,
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
softmax_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Hook called when a block is offloaded during prefill phase.
|
||||
Compute prefill attention.
|
||||
|
||||
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
|
||||
Override this to collect metadata about blocks (e.g., min/max keys
|
||||
for Quest-style selection). Default implementation does nothing.
|
||||
The entire KV cache for this layer is on GPU. Compute attention
|
||||
between Q and K/V, optionally using sparse mask from estimate().
|
||||
|
||||
Args:
|
||||
cpu_block_id: The CPU block ID that will be written
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||
layer_id: Transformer layer index
|
||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
|
||||
num_valid_tokens: Number of valid tokens in this block
|
||||
softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
|
||||
|
||||
Returns:
|
||||
Attention output [seq_len, num_heads, head_dim]
|
||||
"""
|
||||
pass
|
||||
|
||||
def on_decode_offload(
|
||||
def compute_decode(
|
||||
self,
|
||||
cpu_block_id: int,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer_id: int,
|
||||
k_cache: torch.Tensor,
|
||||
num_valid_tokens: int,
|
||||
) -> None:
|
||||
softmax_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Hook called when a block is offloaded during decode phase.
|
||||
Compute decode attention.
|
||||
|
||||
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
|
||||
Override this to update metadata about blocks. Default implementation
|
||||
does nothing.
|
||||
KV is provided from ring buffer, containing prefill tokens + decoded tokens.
|
||||
Default implementation uses FlashAttention.
|
||||
|
||||
Args:
|
||||
cpu_block_id: The CPU block ID that will be written
|
||||
q: Query tensor [1, num_heads, head_dim]
|
||||
k: Key tensor [context_len+1, num_kv_heads, head_dim]
|
||||
v: Value tensor [context_len+1, num_kv_heads, head_dim]
|
||||
layer_id: Transformer layer index
|
||||
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
|
||||
num_valid_tokens: Number of valid tokens in this block
|
||||
softmax_scale: Softmax scaling factor
|
||||
|
||||
Returns:
|
||||
Attention output [1, num_heads, head_dim]
|
||||
"""
|
||||
pass
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
|
||||
context_len = k.shape[0]
|
||||
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
|
||||
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
|
||||
|
||||
return flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=1,
|
||||
max_seqlen_k=context_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=False,
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
@@ -189,32 +217,9 @@ class SparsePolicy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def sparse_prefill_attention(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer_id: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute sparse attention for prefill phase.
|
||||
|
||||
This method is called when supports_prefill=True and the policy
|
||||
is used for GPU-only sparse prefill (no CPU offload).
|
||||
|
||||
Args:
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||
layer_id: Current transformer layer index
|
||||
|
||||
Returns:
|
||||
Attention output [seq_len, num_heads, head_dim]
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement sparse_prefill_attention. "
|
||||
"Set supports_prefill=False or implement this method."
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}()"
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
SparsePolicy = AttentionPolicy
|
||||
|
||||
Reference in New Issue
Block a user