[WIP] need refactor.

This commit is contained in:
Zijie Tian
2026-01-22 22:20:34 +08:00
parent 69b779e252
commit 5fb0f67295
11 changed files with 514 additions and 548 deletions

View File

@@ -1,13 +1,18 @@
"""
Base class for sparse attention policies.
Base class for attention policies in layerwise offload mode.
Sparse attention policies determine which KV cache blocks to load
from CPU for each query chunk during chunked attention computation.
AttentionPolicy defines the interface for all attention computation,
including full attention and sparse attention methods like XAttention.
Key methods:
- estimate(): Compute sparse attention mask (optional, returns None for full attention)
- compute_prefill(): Compute prefill attention
- compute_decode(): Compute decode attention (default implementation provided)
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Any
from typing import List, Optional, Tuple
import torch
# Import SparsePolicyType from config to avoid circular imports
@@ -17,10 +22,10 @@ from nanovllm.config import SparsePolicyType
@dataclass
class PolicyContext:
"""
Context passed to sparse policy for block selection.
Context passed to an attention policy for sparse estimation and attention computation.
This dataclass contains all information needed by a sparse policy
to decide which blocks to load for the current query chunk.
This dataclass contains all information needed by an attention policy
for sparse estimation and attention computation.
"""
query_chunk_idx: int
@@ -49,40 +54,41 @@ class PolicyContext:
"""Total KV sequence length so far (for reference)."""
class SparsePolicy(ABC):
class AttentionPolicy(ABC):
"""
Abstract base class for sparse attention policies.
Base class for attention policies in layerwise offload mode.
Subclass this and implement select_blocks() to create custom
sparse attention patterns. The policy receives context about
the current query chunk and returns which KV blocks to load.
All attention computation goes through a policy, including both
full attention and sparse attention methods.
The policy interface is designed for layerwise offload where:
- The entire KV cache for a layer is on GPU during computation
- No need for block loading from CPU during attention
- estimate() returns a sparse mask (or None for full attention)
- compute_prefill()/compute_decode() perform the actual attention
Attributes:
supports_prefill: Whether this policy can be used for prefill phase.
supports_decode: Whether this policy can be used for decode phase.
Example:
class MySparsePolicy(SparsePolicy):
supports_prefill = False # decode-only policy
class MyPolicy(AttentionPolicy):
supports_prefill = True
supports_decode = True
def select_blocks(self, available_blocks, ctx):
# Load first block and last 2 blocks
if len(available_blocks) <= 3:
return available_blocks
return [available_blocks[0]] + available_blocks[-2:]
def estimate(self, q, k, layer_id):
# Return sparse mask or None
return None
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
# Compute attention
return flash_attn_varlen_func(q, k, v, ...)
"""
# Compatibility flags - override in subclasses
supports_prefill: bool = True
supports_decode: bool = True
# Whether this policy requires selective block loading during decode
# If True: OffloadEngine will call select_blocks() before loading KV from CPU
# If False: OffloadEngine will load all blocks (select_blocks ignored for load)
# Example: MInference=False (only affects attention), Quest=True (affects load)
requires_block_selection: bool = False
def initialize(
self,
num_layers: int,
@@ -96,7 +102,7 @@ class SparsePolicy(ABC):
Initialize policy resources.
Called by the framework after KV cache is allocated. Override this
to create metadata structures (e.g., BlockMetadataManager for Quest).
to create metadata structures or pre-allocate buffers.
Default implementation does nothing.
Args:
@@ -109,76 +115,98 @@ class SparsePolicy(ABC):
"""
pass
@abstractmethod
def select_blocks(
def estimate(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Optional[torch.Tensor]:
"""
Select which KV blocks to load for the current query chunk.
Estimate sparse attention mask.
This is the core method that defines the sparse attention pattern.
The returned blocks will be loaded from CPU to GPU for attention
computation against the current query chunk.
For sparse policies (e.g., XAttention), computes block-level importance
and returns a boolean mask indicating which blocks to attend.
For full attention policy, returns None.
This corresponds to xattn_estimate() in COMPASS.
Args:
available_blocks: List of CPU block IDs that contain KV cache
from previous chunks. These are ordered by
their position in the sequence.
ctx: PolicyContext with information about the current query
chunk, layer, phase (prefill/decode), etc.
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
Returns:
List of block IDs to load (must be a subset of available_blocks).
The order may affect performance (sequential access is faster).
Returning [] means no previous blocks will be loaded.
sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
or None for full attention
"""
pass
return None
def on_prefill_offload(
@abstractmethod
def compute_prefill(
self,
cpu_block_id: int,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
softmax_scale: float,
) -> torch.Tensor:
"""
Hook called when a block is offloaded during prefill phase.
Compute prefill attention.
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
Override this to collect metadata about blocks (e.g., min/max keys
for Quest-style selection). Default implementation does nothing.
The entire KV cache for this layer is on GPU. Compute attention
between Q and K/V, optionally using sparse mask from estimate().
Args:
cpu_block_id: The CPU block ID that will be written
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
pass
def on_decode_offload(
def compute_decode(
self,
cpu_block_id: int,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
softmax_scale: float,
) -> torch.Tensor:
"""
Hook called when a block is offloaded during decode phase.
Compute decode attention.
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
Override this to update metadata about blocks. Default implementation
does nothing.
KV is provided from ring buffer, containing prefill tokens + decoded tokens.
Default implementation uses FlashAttention.
Args:
cpu_block_id: The CPU block ID that will be written
q: Query tensor [1, num_heads, head_dim]
k: Key tensor [context_len+1, num_kv_heads, head_dim]
v: Value tensor [context_len+1, num_kv_heads, head_dim]
layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
softmax_scale: Softmax scaling factor
Returns:
Attention output [1, num_heads, head_dim]
"""
pass
# Function-level import: flash_attn is imported when this default
# implementation is called, not when the module is loaded.
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# Number of key positions supplied (first dimension of k). The docstring
# gives k as [context_len+1, ...], so this local holds the full KV length
# including the newest token, not the pre-decode context length.
context_len = k.shape[0]
# Cumulative sequence boundaries for the varlen kernel: one sequence with a
# single query token and context_len key tokens.
# NOTE(review): both tensors are rebuilt on q.device on every call — consider
# caching or reusing them if this shows up in decode profiles.
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=context_len,
softmax_scale=softmax_scale,
# No causal mask is requested: the single query token attends to every
# key position passed in.
causal=False,
)
def reset(self) -> None:
"""
@@ -189,32 +217,9 @@ class SparsePolicy(ABC):
"""
pass
def sparse_prefill_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
) -> torch.Tensor:
"""
Compute sparse attention for prefill phase.
This method is called when supports_prefill=True and the policy
is used for GPU-only sparse prefill (no CPU offload).
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Current transformer layer index
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
# Base-class fallback: always raises NotImplementedError, naming the concrete
# class. A subclass must either override this method or set
# supports_prefill=False, as the message states.
raise NotImplementedError(
f"{self.__class__.__name__} does not implement sparse_prefill_attention. "
"Set supports_prefill=False or implement this method."
)
def __repr__(self) -> str:
# Debug representation: the concrete class name followed by "()"; no
# instance state is included.
return f"{self.__class__.__name__}()"
# Backward compatibility alias
SparsePolicy = AttentionPolicy