[WIP] need refactor.
This commit is contained in:
@@ -62,6 +62,7 @@ class Config:
|
||||
xattn_keep_sink: bool = False # Always keep first block (sink tokens)
|
||||
xattn_keep_recent: bool = False # Always keep recent diagonal blocks
|
||||
xattn_norm: float = 1.0 # Normalization factor for attention scores
|
||||
xattn_use_bsa: bool = True # Use Block Sparse Attention library (requires installation)
|
||||
|
||||
def __post_init__(self):
|
||||
assert os.path.isdir(self.model)
|
||||
|
||||
@@ -57,8 +57,8 @@ class ModelRunner:
|
||||
load_model(self.model, config.model)
|
||||
self.sampler = GreedySampler()
|
||||
|
||||
# Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache)
|
||||
self.sparse_prefill_policy = None
|
||||
# Initialize attention_policy before warmup (will be configured in allocate_kv_cache)
|
||||
self.attention_policy = None
|
||||
|
||||
#> Disable warmup for debugging
|
||||
self.warmup_model()
|
||||
@@ -178,38 +178,35 @@ class ModelRunner:
|
||||
# Create KV cache manager using factory
|
||||
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
|
||||
|
||||
# Create sparse prefill policy
|
||||
# This is used for both GPU-only and CPU offload modes when policy supports prefill
|
||||
self.sparse_prefill_policy = None
|
||||
if config.sparse_policy != SparsePolicyType.FULL:
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy
|
||||
# Create attention policy (always, including FULL)
|
||||
# In layerwise offload mode, all attention goes through the policy
|
||||
from nanovllm.kvcache.sparse import create_attention_policy
|
||||
|
||||
# Get policy-specific parameters based on type
|
||||
if config.sparse_policy == SparsePolicyType.XATTN:
|
||||
policy_kwargs = {
|
||||
"stride": config.xattn_stride,
|
||||
"threshold": config.xattn_threshold,
|
||||
"chunk_size": config.xattn_chunk_size,
|
||||
"use_triton": config.xattn_use_triton,
|
||||
"keep_sink": config.xattn_keep_sink,
|
||||
"keep_recent": config.xattn_keep_recent,
|
||||
"norm": config.xattn_norm,
|
||||
}
|
||||
else: # MINFERENCE or others
|
||||
policy_kwargs = {
|
||||
"vertical_size": config.minference_vertical_size,
|
||||
"slash_size": config.minference_slash_size,
|
||||
"adaptive_budget": config.minference_adaptive_budget,
|
||||
"num_sink_tokens": config.minference_num_sink_tokens,
|
||||
"num_recent_diags": config.minference_num_recent_diags,
|
||||
}
|
||||
# Get policy-specific parameters based on type
|
||||
if config.sparse_policy == SparsePolicyType.XATTN:
|
||||
policy_kwargs = {
|
||||
"stride": config.xattn_stride,
|
||||
"threshold": config.xattn_threshold,
|
||||
"chunk_size": config.xattn_chunk_size,
|
||||
"use_triton": config.xattn_use_triton,
|
||||
"keep_sink": config.xattn_keep_sink,
|
||||
"keep_recent": config.xattn_keep_recent,
|
||||
"norm": config.xattn_norm,
|
||||
"use_bsa": config.xattn_use_bsa,
|
||||
}
|
||||
elif config.sparse_policy == SparsePolicyType.MINFERENCE:
|
||||
policy_kwargs = {
|
||||
"vertical_size": config.minference_vertical_size,
|
||||
"slash_size": config.minference_slash_size,
|
||||
"adaptive_budget": config.minference_adaptive_budget,
|
||||
"num_sink_tokens": config.minference_num_sink_tokens,
|
||||
"num_recent_diags": config.minference_num_recent_diags,
|
||||
}
|
||||
else: # FULL or QUEST
|
||||
policy_kwargs = {}
|
||||
|
||||
policy = create_sparse_policy(config.sparse_policy, **policy_kwargs)
|
||||
|
||||
# Only use if policy supports sparse prefill
|
||||
if policy.supports_prefill:
|
||||
self.sparse_prefill_policy = policy
|
||||
logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}")
|
||||
self.attention_policy = create_attention_policy(config.sparse_policy, **policy_kwargs)
|
||||
logger.info(f"Attention policy: {self.attention_policy}")
|
||||
|
||||
# Allocate cache through manager
|
||||
self.kvcache_manager.allocate_cache(
|
||||
@@ -395,7 +392,7 @@ class ModelRunner:
|
||||
|
||||
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
|
||||
slot_mapping, None, block_tables,
|
||||
sparse_prefill_policy=self.sparse_prefill_policy)
|
||||
attention_policy=self.attention_policy)
|
||||
return input_ids, positions
|
||||
|
||||
def prepare_decode(self, seqs: list[Sequence]):
|
||||
@@ -592,21 +589,11 @@ class ModelRunner:
|
||||
# RoPE
|
||||
q, k = layer.self_attn.rotary_emb(positions, q, k)
|
||||
|
||||
# Sparse or Full attention (uses k, v directly - before store!)
|
||||
if self.sparse_prefill_policy is not None:
|
||||
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
|
||||
q, k, v, layer_id
|
||||
)
|
||||
else:
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=total_tokens,
|
||||
max_seqlen_k=total_tokens,
|
||||
softmax_scale=layer.self_attn.attn.scale,
|
||||
causal=True,
|
||||
)
|
||||
# Compute attention using policy (uses k, v directly - before store!)
|
||||
attn_output = self.attention_policy.compute_prefill(
|
||||
q, k, v, layer_id,
|
||||
softmax_scale=layer.self_attn.attn.scale,
|
||||
)
|
||||
|
||||
# O projection
|
||||
attn_output = attn_output.view(total_tokens, -1)
|
||||
@@ -872,23 +859,11 @@ class ModelRunner:
|
||||
# RoPE
|
||||
q, k = layer.self_attn.rotary_emb(positions, q, k)
|
||||
|
||||
# Sparse or Full attention
|
||||
if self.sparse_prefill_policy is not None:
|
||||
# MInference or other sparse prefill policy
|
||||
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
|
||||
q, k, v, layer_id
|
||||
)
|
||||
else:
|
||||
# Full attention using FlashAttention
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=total_tokens,
|
||||
max_seqlen_k=total_tokens,
|
||||
softmax_scale=layer.self_attn.attn.scale,
|
||||
causal=True,
|
||||
)
|
||||
# Compute attention using policy
|
||||
attn_output = self.attention_policy.compute_prefill(
|
||||
q, k, v, layer_id,
|
||||
softmax_scale=layer.self_attn.attn.scale,
|
||||
)
|
||||
|
||||
# O projection
|
||||
attn_output = attn_output.view(total_tokens, -1)
|
||||
|
||||
@@ -1,49 +1,56 @@
|
||||
"""
|
||||
Sparse Attention Policy module.
|
||||
Attention Policy module for layerwise offload mode.
|
||||
|
||||
Provides pluggable policies for selecting which KV blocks to load
|
||||
during chunked attention with CPU offload.
|
||||
Provides pluggable policies for attention computation:
|
||||
- FullAttentionPolicy: Standard FlashAttention (no sparsity)
|
||||
- XAttentionPolicy: Sparse prefill using XAttention algorithm
|
||||
- MInferencePolicy: MInference sparse attention
|
||||
- QuestPolicy: Quest block selection (for chunked offload)
|
||||
|
||||
Usage:
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
|
||||
from nanovllm.kvcache.sparse import create_attention_policy, SparsePolicyType
|
||||
|
||||
# Create policy using factory function
|
||||
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)
|
||||
policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
|
||||
|
||||
# Use policy for attention
|
||||
attn_output = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
|
||||
|
||||
# Or create custom policy
|
||||
class MyPolicy(SparsePolicy):
|
||||
class MyPolicy(AttentionPolicy):
|
||||
supports_prefill = True
|
||||
supports_decode = True
|
||||
|
||||
def select_blocks(self, available_blocks, ctx):
|
||||
return available_blocks[:5] # Just first 5 blocks
|
||||
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
|
||||
# Custom attention computation
|
||||
...
|
||||
"""
|
||||
|
||||
from nanovllm.config import SparsePolicyType
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||
from nanovllm.kvcache.sparse.policy import AttentionPolicy, SparsePolicy, PolicyContext
|
||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||
from nanovllm.kvcache.sparse.minference import MInferencePolicy
|
||||
from nanovllm.kvcache.sparse.xattn import XAttentionPolicy
|
||||
|
||||
|
||||
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
||||
def create_attention_policy(policy_type: SparsePolicyType, **kwargs) -> AttentionPolicy:
|
||||
"""
|
||||
Create a sparse policy instance from an enum type.
|
||||
Create an attention policy instance from an enum type.
|
||||
|
||||
The returned policy is not yet initialized. Call policy.initialize()
|
||||
or let the framework call it during KV cache allocation.
|
||||
All attention (including full attention) goes through a policy in layerwise
|
||||
offload mode. The policy is responsible for computing prefill/decode attention.
|
||||
|
||||
Args:
|
||||
policy_type: SparsePolicyType enum value
|
||||
policy_type: SparsePolicyType enum value (FULL, XATTN, MINFERENCE, QUEST)
|
||||
**kwargs: Policy-specific configuration options
|
||||
|
||||
Returns:
|
||||
SparsePolicy instance (not initialized)
|
||||
AttentionPolicy instance
|
||||
|
||||
Example:
|
||||
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
|
||||
policy.initialize(num_layers=28, num_kv_heads=8, ...)
|
||||
policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
|
||||
attn_out = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
|
||||
"""
|
||||
if policy_type == SparsePolicyType.FULL:
|
||||
return FullAttentionPolicy()
|
||||
@@ -75,21 +82,32 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
|
||||
keep_sink=kwargs.get("keep_sink", False),
|
||||
keep_recent=kwargs.get("keep_recent", False),
|
||||
norm=kwargs.get("norm", 1.0),
|
||||
use_bsa=kwargs.get("use_bsa", True),
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown policy type: {policy_type}")
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
create_sparse_policy = create_attention_policy
|
||||
|
||||
|
||||
__all__ = [
|
||||
# New interface
|
||||
"AttentionPolicy",
|
||||
"create_attention_policy",
|
||||
# Backward compatibility
|
||||
"SparsePolicy",
|
||||
"create_sparse_policy",
|
||||
# Common types
|
||||
"PolicyContext",
|
||||
"SparsePolicyType",
|
||||
# Policy implementations
|
||||
"FullAttentionPolicy",
|
||||
"QuestPolicy",
|
||||
"QuestConfig",
|
||||
"BlockMetadataManager",
|
||||
"MInferencePolicy",
|
||||
"XAttentionPolicy",
|
||||
"create_sparse_policy",
|
||||
]
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
"""
|
||||
Full attention policy - loads all blocks (no sparsity).
|
||||
Full attention policy - standard FlashAttention without sparsity.
|
||||
|
||||
This serves as a baseline and default policy when sparse
|
||||
attention is not needed.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from .policy import SparsePolicy, PolicyContext
|
||||
from typing import Optional
|
||||
import torch
|
||||
from .policy import AttentionPolicy
|
||||
|
||||
|
||||
class FullAttentionPolicy(SparsePolicy):
|
||||
class FullAttentionPolicy(AttentionPolicy):
|
||||
"""
|
||||
Full attention policy that loads all available blocks.
|
||||
Full attention policy using FlashAttention (no sparsity).
|
||||
|
||||
This is the default behavior with no sparsity - all previous
|
||||
KV cache blocks are loaded for each query chunk.
|
||||
This is the default behavior with standard causal attention.
|
||||
All tokens attend to all previous tokens.
|
||||
|
||||
Use this as:
|
||||
- A baseline for comparing sparse policies
|
||||
@@ -25,15 +26,55 @@ class FullAttentionPolicy(SparsePolicy):
|
||||
# Full attention supports both prefill and decode
|
||||
supports_prefill = True
|
||||
supports_decode = True
|
||||
requires_block_selection = False # Load all blocks, no selective loading
|
||||
|
||||
def estimate(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    layer_id: int,
) -> Optional[torch.Tensor]:
    """
    Full attention - no sparse mask needed.

    Args:
        q: Query tensor (unused).
        k: Key tensor (unused).
        layer_id: Transformer layer index (unused).

    Returns:
        Always None, which signals that full attention should be used.
    """
    return None
|
||||
|
||||
def compute_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
) -> torch.Tensor:
    """
    Compute full causal attention using FlashAttention.

    The input is handed to the varlen kernel as one packed sequence whose
    boundaries are [0, seq_len].

    Args:
        q: Query tensor [seq_len, num_heads, head_dim]
        k: Key tensor [seq_len, num_kv_heads, head_dim]
        v: Value tensor [seq_len, num_kv_heads, head_dim]
        layer_id: Transformer layer index (not used by full attention)
        softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

    Returns:
        Attention output [seq_len, num_heads, head_dim]
    """
    seq_len = q.shape[0]
    if seq_len == 0:
        # No query rows: the output has q's (empty) shape, so skip the kernel.
        return torch.empty_like(q)

    # Function-scope import, as in the original implementation.
    from flash_attn.flash_attn_interface import flash_attn_varlen_func

    # Single sequence covering every token in q/k/v.
    cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=seq_len,
        max_seqlen_k=seq_len,
        softmax_scale=softmax_scale,
        causal=True,
    )
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "FullAttentionPolicy()"
|
||||
|
||||
@@ -10,10 +10,10 @@ from typing import List, Tuple, Optional
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||
from nanovllm.kvcache.sparse.policy import AttentionPolicy, PolicyContext
|
||||
|
||||
|
||||
class MInferencePolicy(SparsePolicy):
|
||||
class MInferencePolicy(AttentionPolicy):
|
||||
"""
|
||||
MInference sparse prefill policy using vertical + slash pattern.
|
||||
|
||||
@@ -347,6 +347,33 @@ class MInferencePolicy(SparsePolicy):
|
||||
|
||||
return o
|
||||
|
||||
def compute_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
) -> torch.Tensor:
    """
    Compute MInference sparse prefill attention.

    This is the unified attention-policy entry point. It delegates to
    sparse_prefill_attention(), which does not take a softmax scale
    (MInference computes it internally from head_dim). Because the caller's
    value is dropped, a warning is emitted when it differs from the
    conventional head_dim ** -0.5 so the mismatch is not silent.

    Args:
        q: Query tensor [seq_len, num_heads, head_dim]
        k: Key tensor [seq_len, num_kv_heads, head_dim]
        v: Value tensor [seq_len, num_kv_heads, head_dim]
        layer_id: Transformer layer index
        softmax_scale: Softmax scaling factor (not forwarded)

    Returns:
        Attention output [seq_len, num_heads, head_dim]
    """
    import math
    import warnings

    head_dim = q.shape[-1]
    # NOTE(review): assumes the internal scale is head_dim ** -0.5 - confirm
    # against sparse_prefill_attention.
    if head_dim > 0 and not math.isclose(
        softmax_scale, head_dim ** -0.5, rel_tol=1e-4
    ):
        warnings.warn(
            f"MInferencePolicy ignores softmax_scale={softmax_scale}; "
            f"the scale is derived internally from head_dim={head_dim}.",
            RuntimeWarning,
            stacklevel=2,
        )

    return self.sparse_prefill_attention(q, k, v, layer_id)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"MInferencePolicy("
|
||||
f"adaptive_budget={self.adaptive_budget}, "
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
"""
|
||||
Base class for sparse attention policies.
|
||||
Base class for attention policies in layerwise offload mode.
|
||||
|
||||
Sparse attention policies determine which KV cache blocks to load
|
||||
from CPU for each query chunk during chunked attention computation.
|
||||
AttentionPolicy defines the interface for all attention computation,
|
||||
including full attention and sparse attention methods like XAttention.
|
||||
|
||||
Key methods:
|
||||
- estimate(): Compute sparse attention mask (optional, returns None for full attention)
|
||||
- compute_prefill(): Compute prefill attention
|
||||
- compute_decode(): Compute decode attention (default implementation provided)
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Any
|
||||
from typing import List, Optional, Tuple
|
||||
import torch
|
||||
|
||||
# Import SparsePolicyType from config to avoid circular imports
|
||||
@@ -17,10 +22,10 @@ from nanovllm.config import SparsePolicyType
|
||||
@dataclass
|
||||
class PolicyContext:
|
||||
"""
|
||||
Context passed to sparse policy for block selection.
|
||||
Context passed to attention policy for block selection.
|
||||
|
||||
This dataclass contains all information needed by a sparse policy
|
||||
to decide which blocks to load for the current query chunk.
|
||||
This dataclass contains all information needed by an attention policy
|
||||
for sparse estimation and attention computation.
|
||||
"""
|
||||
|
||||
query_chunk_idx: int
|
||||
@@ -49,40 +54,41 @@ class PolicyContext:
|
||||
"""Total KV sequence length so far (for reference)."""
|
||||
|
||||
|
||||
class SparsePolicy(ABC):
|
||||
class AttentionPolicy(ABC):
|
||||
"""
|
||||
Abstract base class for sparse attention policies.
|
||||
Base class for attention policies in layerwise offload mode.
|
||||
|
||||
Subclass this and implement select_blocks() to create custom
|
||||
sparse attention patterns. The policy receives context about
|
||||
the current query chunk and returns which KV blocks to load.
|
||||
All attention computation goes through a policy, including both
|
||||
full attention and sparse attention methods.
|
||||
|
||||
The policy interface is designed for layerwise offload where:
|
||||
- The entire KV cache for a layer is on GPU during computation
|
||||
- No need for block loading from CPU during attention
|
||||
- estimate() returns a sparse mask (or None for full attention)
|
||||
- compute_prefill()/compute_decode() perform the actual attention
|
||||
|
||||
Attributes:
|
||||
supports_prefill: Whether this policy can be used for prefill phase.
|
||||
supports_decode: Whether this policy can be used for decode phase.
|
||||
|
||||
Example:
|
||||
class MySparsePolicy(SparsePolicy):
|
||||
supports_prefill = False # decode-only policy
|
||||
class MyPolicy(AttentionPolicy):
|
||||
supports_prefill = True
|
||||
supports_decode = True
|
||||
|
||||
def select_blocks(self, available_blocks, ctx):
|
||||
# Load first block and last 2 blocks
|
||||
if len(available_blocks) <= 3:
|
||||
return available_blocks
|
||||
return [available_blocks[0]] + available_blocks[-2:]
|
||||
def estimate(self, q, k, layer_id):
|
||||
# Return sparse mask or None
|
||||
return None
|
||||
|
||||
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
|
||||
# Compute attention
|
||||
return flash_attn_varlen_func(q, k, v, ...)
|
||||
"""
|
||||
|
||||
# Compatibility flags - override in subclasses
|
||||
supports_prefill: bool = True
|
||||
supports_decode: bool = True
|
||||
|
||||
# Whether this policy requires selective block loading during decode
|
||||
# If True: OffloadEngine will call select_blocks() before loading KV from CPU
|
||||
# If False: OffloadEngine will load all blocks (select_blocks ignored for load)
|
||||
# Example: MInference=False (only affects attention), Quest=True (affects load)
|
||||
requires_block_selection: bool = False
|
||||
|
||||
def initialize(
|
||||
self,
|
||||
num_layers: int,
|
||||
@@ -96,7 +102,7 @@ class SparsePolicy(ABC):
|
||||
Initialize policy resources.
|
||||
|
||||
Called by the framework after KV cache is allocated. Override this
|
||||
to create metadata structures (e.g., BlockMetadataManager for Quest).
|
||||
to create metadata structures or pre-allocate buffers.
|
||||
Default implementation does nothing.
|
||||
|
||||
Args:
|
||||
@@ -109,76 +115,98 @@ class SparsePolicy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def estimate(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    layer_id: int,
) -> Optional[torch.Tensor]:
    """
    Estimate sparse attention mask.

    For sparse policies (e.g., XAttention), computes block-level importance
    and returns a boolean mask indicating which blocks to attend.
    For full attention policy, returns None.

    This corresponds to xattn_estimate() in COMPASS.

    This base implementation is a concrete default (not abstract): policies
    that produce no mask inherit it and get None.

    Args:
        q: Query tensor [seq_len, num_heads, head_dim]
        k: Key tensor [seq_len, num_kv_heads, head_dim]
        layer_id: Transformer layer index

    Returns:
        sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
            or None for full attention
    """
    return None
|
||||
|
||||
@abstractmethod
def compute_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
) -> torch.Tensor:
    """
    Compute prefill attention.

    The entire KV cache for this layer is on GPU. Compute attention
    between Q and K/V, optionally using sparse mask from estimate().

    Every concrete policy must override this method.

    Args:
        q: Query tensor [seq_len, num_heads, head_dim]
        k: Key tensor [seq_len, num_kv_heads, head_dim]
        v: Value tensor [seq_len, num_kv_heads, head_dim]
        layer_id: Transformer layer index
        softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

    Returns:
        Attention output [seq_len, num_heads, head_dim]
    """
    pass
|
||||
|
||||
def compute_decode(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
) -> torch.Tensor:
    """
    Compute decode attention.

    KV is provided from ring buffer, containing prefill tokens + decoded tokens.
    Default implementation uses FlashAttention.

    Args:
        q: Query tensor [1, num_heads, head_dim]
        k: Key tensor [context_len+1, num_kv_heads, head_dim]
        v: Value tensor [context_len+1, num_kv_heads, head_dim]
        layer_id: Transformer layer index (not used by this default)
        softmax_scale: Softmax scaling factor

    Returns:
        Attention output [1, num_heads, head_dim]
    """
    # Function-scope import, as in the original implementation.
    from flash_attn.flash_attn_interface import flash_attn_varlen_func

    # Number of key/value rows handed in by the caller.
    context_len = k.shape[0]
    # One query token attending over one sequence of context_len keys.
    cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
    cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)

    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=1,
        max_seqlen_k=context_len,
        softmax_scale=softmax_scale,
        # The single query row may see every key row, so no causal mask.
        causal=False,
    )
|
||||
|
||||
def reset(self) -> None:
|
||||
"""
|
||||
@@ -189,32 +217,9 @@ class SparsePolicy(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def sparse_prefill_attention(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer_id: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute sparse attention for prefill phase.
|
||||
|
||||
This method is called when supports_prefill=True and the policy
|
||||
is used for GPU-only sparse prefill (no CPU offload).
|
||||
|
||||
Args:
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||
layer_id: Current transformer layer index
|
||||
|
||||
Returns:
|
||||
Attention output [seq_len, num_heads, head_dim]
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement sparse_prefill_attention. "
|
||||
"Set supports_prefill=False or implement this method."
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}()"
|
||||
|
||||
|
||||
# Backward compatibility alias
|
||||
SparsePolicy = AttentionPolicy
|
||||
|
||||
@@ -11,7 +11,7 @@ import logging
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Tuple, Optional
|
||||
from .policy import SparsePolicy, PolicyContext
|
||||
from .policy import AttentionPolicy, PolicyContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -137,7 +137,7 @@ class QuestConfig:
|
||||
"""Always include this many recent blocks (last N blocks), in addition to Top-K."""
|
||||
|
||||
|
||||
class QuestPolicy(SparsePolicy):
|
||||
class QuestPolicy(AttentionPolicy):
|
||||
"""
|
||||
Quest-style Top-K block selection using min/max key bounds.
|
||||
|
||||
@@ -317,6 +317,25 @@ class QuestPolicy(SparsePolicy):
|
||||
if self.metadata is not None:
|
||||
self.metadata.reset()
|
||||
|
||||
def compute_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
) -> torch.Tensor:
    """
    Quest does not support prefill - raises error.

    Quest is a decode-only policy for selective block loading.
    For prefill, use FullAttentionPolicy or XAttentionPolicy.

    Args:
        q: Query tensor (unused).
        k: Key tensor (unused).
        v: Value tensor (unused).
        layer_id: Transformer layer index (unused).
        softmax_scale: Softmax scaling factor (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError(
        "QuestPolicy does not support prefill. "
        "Use FullAttentionPolicy or XAttentionPolicy for prefill."
    )
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"QuestPolicy(topk={self.config.topk_blocks}, "
|
||||
|
||||
@@ -4,48 +4,56 @@ XAttention sparse attention policy for nano-vllm.
|
||||
Implements the XAttention algorithm from COMPASS, using chunked estimation
|
||||
and block sparse attention for efficient long-context inference.
|
||||
|
||||
Architecture:
|
||||
XAttention = Estimate (Triton) + Compute (BSA)
|
||||
- Estimate: xattn_estimate() computes block-level importance scores
|
||||
- Compute: block_sparse_attn_func() executes sparse attention
|
||||
|
||||
Reference: COMPASS/compass/src/Xattention.py
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||
from nanovllm.kvcache.sparse.kernels import (
|
||||
flat_group_gemm_fuse_reshape,
|
||||
softmax_fuse_block_sum,
|
||||
)
|
||||
from nanovllm.kvcache.sparse.utils import find_blocks_chunked
|
||||
from nanovllm.kvcache.sparse.policy import AttentionPolicy
|
||||
|
||||
# BSA block size is fixed at 128 (hardcoded in block_sparse_attn)
|
||||
BSA_BLOCK_SIZE = 128
|
||||
|
||||
|
||||
class XAttentionPolicy(SparsePolicy):
|
||||
class XAttentionPolicy(AttentionPolicy):
|
||||
"""
|
||||
XAttention sparse prefill policy using chunked estimation + block sparse attention.
|
||||
|
||||
This policy estimates sparse attention patterns by:
|
||||
1. Chunked QK computation using Triton kernels
|
||||
1. Chunked QK computation using Triton kernels (via nanovllm.ops.xattn)
|
||||
2. Block-wise softmax with importance scores
|
||||
3. Block selection based on threshold
|
||||
4. Block sparse attention computation
|
||||
4. Block sparse attention computation using MIT-HAN-LAB BSA library
|
||||
|
||||
The key method is estimate() which calls xattn_estimate() from nanovllm.ops
|
||||
to compute the sparse attention mask.
|
||||
|
||||
Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
|
||||
BSA library: https://github.com/mit-han-lab/Block-Sparse-Attention
|
||||
"""
|
||||
|
||||
supports_prefill = True
|
||||
supports_decode = False # XAttention is prefill-only
|
||||
requires_block_selection = False # Only affects attention computation
|
||||
supports_decode = True # Uses default FlashAttention for decode
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stride: int = 8,
|
||||
threshold: float = 0.9,
|
||||
chunk_size: Optional[int] = None,
|
||||
block_size: int = 128,
|
||||
chunk_size: int = 16384,
|
||||
use_triton: bool = True,
|
||||
keep_sink: bool = False,
|
||||
keep_recent: bool = False,
|
||||
norm: float = 1.0,
|
||||
use_bsa: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize XAttention policy.
|
||||
@@ -53,19 +61,28 @@ class XAttentionPolicy(SparsePolicy):
|
||||
Args:
|
||||
stride: Stride for reorganizing Q/K (default: 8)
|
||||
threshold: Block selection threshold, 0-1 (default: 0.9)
|
||||
chunk_size: Chunk size for estimation (auto if None)
|
||||
block_size: Block size for sparse attention (default: 128, must match BSA)
|
||||
chunk_size: Chunk size for estimation (default: 16384)
|
||||
use_triton: Use Triton kernels (requires SM 80+)
|
||||
keep_sink: Always keep first block (sink tokens)
|
||||
keep_recent: Always keep recent diagonal blocks
|
||||
norm: Normalization factor for attention scores
|
||||
use_bsa: Use Block Sparse Attention library (default: True)
|
||||
"""
|
||||
self.stride = stride
|
||||
self.threshold = threshold
|
||||
self.block_size = block_size
|
||||
self.chunk_size = chunk_size
|
||||
self.use_triton = use_triton
|
||||
self.keep_sink = keep_sink
|
||||
self.keep_recent = keep_recent
|
||||
self.norm = norm
|
||||
self.use_bsa = use_bsa
|
||||
|
||||
# BSA requires block_size = 128
|
||||
if self.use_bsa and self.block_size != BSA_BLOCK_SIZE:
|
||||
print(f"XAttention: BSA requires block_size=128, adjusting from {self.block_size}")
|
||||
self.block_size = BSA_BLOCK_SIZE
|
||||
|
||||
# Check Triton availability
|
||||
if self.use_triton:
|
||||
@@ -79,379 +96,206 @@ class XAttentionPolicy(SparsePolicy):
|
||||
self.use_triton = False
|
||||
print("XAttention: Triton not available. Falling back to PyTorch.")
|
||||
|
||||
def select_blocks(
|
||||
# Check BSA availability
|
||||
if self.use_bsa:
|
||||
try:
|
||||
from block_sparse_attn import block_sparse_attn_func
|
||||
except ImportError:
|
||||
self.use_bsa = False
|
||||
print("XAttention: block_sparse_attn not available. Falling back to FlashAttention.")
|
||||
|
||||
def estimate(
|
||||
self,
|
||||
available_blocks: List[int],
|
||||
ctx: PolicyContext,
|
||||
) -> List[int]:
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
layer_id: int,
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Select blocks for decode phase.
|
||||
Estimate sparse attention mask using XAttention algorithm.
|
||||
|
||||
XAttention is prefill-only, so this method is only used as a fallback.
|
||||
Returns all available blocks by default.
|
||||
Calls xattn_estimate() from nanovllm.ops.xattn to compute block-level
|
||||
importance scores and generate a sparse boolean mask.
|
||||
|
||||
Args:
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||
layer_id: Transformer layer index
|
||||
|
||||
Returns:
|
||||
sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
|
||||
or None if estimation fails (fallback to full attention)
|
||||
"""
|
||||
# XAttention is prefill-only, but we need to implement this abstract method
|
||||
# Since requires_block_selection=False, this won't be called for loading
|
||||
return available_blocks
|
||||
try:
|
||||
from nanovllm.ops.xattn import xattn_estimate
|
||||
|
||||
def sparse_prefill_attention(
|
||||
seq_len, num_heads, head_dim = q.shape
|
||||
num_kv_heads = k.shape[1]
|
||||
|
||||
# Convert to [batch, heads, seq, dim] format expected by xattn_estimate
|
||||
q_bhsd = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim]
|
||||
k_bhsd = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, seq_len, head_dim]
|
||||
|
||||
# Handle GQA: expand k to match q heads for estimation
|
||||
if num_kv_heads != num_heads:
|
||||
# GQA: expand k by repeating
|
||||
repeat_factor = num_heads // num_kv_heads
|
||||
k_bhsd = k_bhsd.repeat(1, repeat_factor, 1, 1)
|
||||
|
||||
# Call xattn_estimate
|
||||
attn_sums, sparse_mask = xattn_estimate(
|
||||
q_bhsd, k_bhsd,
|
||||
block_size=self.block_size,
|
||||
stride=self.stride,
|
||||
norm=self.norm,
|
||||
threshold=self.threshold,
|
||||
chunk_size=self.chunk_size,
|
||||
use_triton=self.use_triton,
|
||||
causal=True,
|
||||
keep_sink=self.keep_sink,
|
||||
keep_recent=self.keep_recent,
|
||||
)
|
||||
|
||||
return sparse_mask
|
||||
|
||||
except Exception as e:
|
||||
# If estimation fails, return None to use full attention
|
||||
print(f"XAttention estimate failed: {e}, falling back to full attention")
|
||||
return None
|
||||
|
||||
def compute_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer_id: int,
|
||||
softmax_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute XAttention sparse attention for prefill.
|
||||
Compute XAttention sparse prefill attention.
|
||||
|
||||
Flow:
|
||||
1. Call estimate() to get sparse mask
|
||||
2. If mask is None or BSA unavailable, use full FlashAttention
|
||||
3. Otherwise, use block_sparse_attn_func with mask
|
||||
|
||||
Args:
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||
layer_id: Current transformer layer index
|
||||
layer_id: Transformer layer index
|
||||
softmax_scale: Softmax scaling factor
|
||||
|
||||
Returns:
|
||||
Attention output [seq_len, num_heads, head_dim]
|
||||
"""
|
||||
seq_len = q.shape[0]
|
||||
num_heads = q.shape[1]
|
||||
head_dim = q.shape[2]
|
||||
num_kv_heads = k.shape[1]
|
||||
# If BSA is disabled, use full attention directly (skip estimation)
|
||||
if not self.use_bsa:
|
||||
return self._full_attention(q, k, v, softmax_scale)
|
||||
|
||||
# Use FlashAttention directly for CPU offload mode
|
||||
# FlashAttention supports GQA natively
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
# Step 1: Estimate sparse mask
|
||||
sparse_mask = self.estimate(q, k, layer_id)
|
||||
|
||||
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||
# Step 2: Compute attention
|
||||
if sparse_mask is None:
|
||||
# Estimation failed, fallback to full FlashAttention
|
||||
return self._full_attention(q, k, v, softmax_scale)
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=seq_len,
|
||||
max_seqlen_k=seq_len,
|
||||
softmax_scale=1.0 / math.sqrt(head_dim),
|
||||
causal=True,
|
||||
)
|
||||
# Use block sparse attention with mask
|
||||
return self._block_sparse_attention(q, k, v, sparse_mask, softmax_scale)
|
||||
|
||||
return attn_output
|
||||
|
||||
except Exception as e:
|
||||
# Fallback: PyTorch SDPA (supports GQA natively)
|
||||
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=None,
|
||||
is_causal=True,
|
||||
scale=1.0 / math.sqrt(head_dim)
|
||||
)
|
||||
return attn_output
|
||||
|
||||
def _xattn_offload_prefill(
|
||||
def _block_sparse_attention(
|
||||
self,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
causal: bool = True,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
sparse_mask: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Simplified XAttention prefill for CPU offload mode.
|
||||
|
||||
Uses FlashAttention with full context since chunked estimation
|
||||
with full key_states requires special handling.
|
||||
"""
|
||||
batch_size, num_heads, q_len, head_dim = query_states.shape
|
||||
_, _, k_len, _ = key_states.shape
|
||||
|
||||
# Use FlashAttention with full context
|
||||
# In offload mode, keys are already on CPU and loaded as needed
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
|
||||
# Convert to [seq, heads, dim] format
|
||||
q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim]
|
||||
k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||
v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||
|
||||
cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
|
||||
cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=q_len,
|
||||
max_seqlen_k=k_len,
|
||||
softmax_scale=1.0 / math.sqrt(head_dim),
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
# Convert back to [batch, seq, heads, dim]
|
||||
attn_output = attn_output.unsqueeze(0).transpose(1, 2) # [1, q_len, num_heads, head_dim]
|
||||
|
||||
return attn_output
|
||||
|
||||
except Exception as e:
|
||||
# Final fallback: PyTorch SDPA
|
||||
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query_states, key_states, value_states,
|
||||
attn_mask=None,
|
||||
is_causal=causal,
|
||||
scale=1.0 / math.sqrt(head_dim)
|
||||
)
|
||||
return attn_output
|
||||
|
||||
def _xattn_prefill(
|
||||
self,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
stride: int,
|
||||
norm: float,
|
||||
threshold: float,
|
||||
block_size: int = 128,
|
||||
use_triton: bool = True,
|
||||
causal: bool = True,
|
||||
chunk_size: Optional[int] = None,
|
||||
keep_sink: bool = False,
|
||||
keep_recent: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
XAttention prefill implementation.
|
||||
Compute block sparse attention using MIT-HAN-LAB BSA library.
|
||||
|
||||
Args:
|
||||
query_states: [batch, num_heads, q_len, head_dim]
|
||||
key_states: [batch, num_heads, k_len, head_dim]
|
||||
value_states: [batch, num_heads, k_len, head_dim]
|
||||
... other params
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||
sparse_mask: Block mask [batch, num_heads, q_blocks, k_blocks]
|
||||
softmax_scale: Softmax scaling factor
|
||||
|
||||
Returns:
|
||||
Attention output [batch, q_len, num_heads, head_dim]
|
||||
Attention output [seq_len, num_heads, head_dim]
|
||||
"""
|
||||
batch_size, num_heads, k_len, head_dim = key_states.shape
|
||||
_, _, q_len, _ = query_states.shape
|
||||
from block_sparse_attn import block_sparse_attn_func
|
||||
|
||||
# Auto-compute chunk_size if not specified
|
||||
if chunk_size is None:
|
||||
chunk_size = int(
|
||||
max(
|
||||
min(
|
||||
max(2048, 1 << (k_len - 1).bit_length()),
|
||||
128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()),
|
||||
),
|
||||
2048,
|
||||
)
|
||||
)
|
||||
seq_len, num_heads, head_dim = q.shape
|
||||
num_kv_heads = k.shape[1]
|
||||
|
||||
# Phase 1: Estimate sparse pattern
|
||||
attn_sums, approx_simple_mask = self._xattn_estimate(
|
||||
query_states,
|
||||
key_states,
|
||||
block_size=block_size,
|
||||
stride=stride,
|
||||
norm=norm,
|
||||
threshold=threshold,
|
||||
chunk_size=chunk_size,
|
||||
use_triton=use_triton,
|
||||
causal=causal,
|
||||
keep_sink=keep_sink,
|
||||
keep_recent=keep_recent,
|
||||
)
|
||||
# Handle GQA: expand K/V to match Q heads
|
||||
if num_kv_heads != num_heads:
|
||||
repeat_factor = num_heads // num_kv_heads
|
||||
k = k.repeat_interleave(repeat_factor, dim=1)
|
||||
v = v.repeat_interleave(repeat_factor, dim=1)
|
||||
|
||||
# Phase 2: Block sparse attention
|
||||
# For now, use FlashAttention as fallback since block_sparse_attn_func may not be available
|
||||
attn_output = self._block_sparse_attention_fallback(
|
||||
query_states, key_states, value_states,
|
||||
approx_simple_mask, block_size, q_len, k_len
|
||||
# Cumulative sequence lengths (batch=1)
|
||||
cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||
cu_seqlens_k = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||
|
||||
# Head mask type: 1 for all heads using block sparse
|
||||
head_mask_type = torch.ones(num_heads, dtype=torch.int32, device=q.device)
|
||||
|
||||
# Trim sparse_mask to actual block counts
|
||||
q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
|
||||
k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
|
||||
block_mask = sparse_mask[:, :, :q_blocks, :k_blocks].contiguous()
|
||||
|
||||
# Call BSA
|
||||
attn_output = block_sparse_attn_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q, cu_seqlens_k,
|
||||
head_mask_type,
|
||||
None, # streaming_info (left_mask)
|
||||
block_mask,
|
||||
seq_len, seq_len,
|
||||
p_dropout=0.0,
|
||||
deterministic=True,
|
||||
softmax_scale=softmax_scale,
|
||||
is_causal=True,
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
def _xattn_estimate(
|
||||
def _full_attention(
|
||||
self,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
block_size: int,
|
||||
stride: int,
|
||||
norm: float = 1,
|
||||
softmax: bool = True,
|
||||
threshold: float = 0.9,
|
||||
chunk_size: int = 16384,
|
||||
use_triton: bool = True,
|
||||
causal: bool = True,
|
||||
keep_sink: bool = False,
|
||||
keep_recent: bool = False,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
softmax_scale: float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Estimate sparse attention pattern using chunked computation.
|
||||
Compute full causal attention using FlashAttention.
|
||||
|
||||
Args:
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||
softmax_scale: Softmax scaling factor
|
||||
|
||||
Returns:
|
||||
attn_sums: [batch, heads, q_blocks, k_blocks] - importance scores
|
||||
simple_masks: [batch, heads, q_blocks, k_blocks] - boolean masks
|
||||
Attention output [seq_len, num_heads, head_dim]
|
||||
"""
|
||||
batch_size, num_kv_head, k_len, head_dim = key_states.shape
|
||||
batch_size, num_q_head, q_len, head_dim = query_states.shape
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
|
||||
k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
|
||||
q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
|
||||
k_chunk_num = (k_len + k_num_to_pad) // chunk_size
|
||||
k_block_num = (k_len + k_num_to_pad) // block_size
|
||||
q_chunk_num = (q_len + q_num_to_pad) // chunk_size
|
||||
q_block_num = (q_len + q_num_to_pad) // block_size
|
||||
seq_len = q.shape[0]
|
||||
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||
|
||||
# Pad inputs
|
||||
if k_num_to_pad > 0:
|
||||
pad_key_states = F.pad(key_states, (0, 0, 0, k_num_to_pad), value=0)
|
||||
else:
|
||||
pad_key_states = key_states
|
||||
if q_num_to_pad > 0:
|
||||
pad_query_states = F.pad(query_states, (0, 0, 0, q_num_to_pad), value=0)
|
||||
else:
|
||||
pad_query_states = query_states
|
||||
|
||||
reshaped_chunk_size = chunk_size // stride
|
||||
reshaped_block_size = block_size // stride
|
||||
k_reshaped_seq_len = (k_len + k_num_to_pad) // stride
|
||||
|
||||
attn_sum_list = []
|
||||
simple_mask_list = []
|
||||
|
||||
for chunk_idx in range(q_chunk_num):
|
||||
if use_triton:
|
||||
# Triton GEMM + Softmax
|
||||
attn_weights_slice = flat_group_gemm_fuse_reshape(
|
||||
pad_query_states[:, :, (chunk_idx * reshaped_chunk_size) * stride : (chunk_idx * reshaped_chunk_size + reshaped_chunk_size) * stride, :],
|
||||
pad_key_states,
|
||||
stride,
|
||||
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
|
||||
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
|
||||
is_causal=causal,
|
||||
)
|
||||
|
||||
attn_sum = softmax_fuse_block_sum(
|
||||
attn_weights_slice,
|
||||
reshaped_block_size,
|
||||
min(4096, reshaped_block_size),
|
||||
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
|
||||
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
|
||||
k_reshaped_seq_len - (k_num_to_pad // stride),
|
||||
1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
|
||||
is_causal=causal,
|
||||
)
|
||||
else:
|
||||
# PyTorch fallback
|
||||
chunk_size_actual = reshaped_chunk_size
|
||||
chunk_start = chunk_idx * chunk_size_actual
|
||||
chunk_end = chunk_start + chunk_size_actual
|
||||
|
||||
chunked_query = pad_query_states[:, :, chunk_start * stride:chunk_end * stride:stride, :]
|
||||
attn_weights_slice = torch.matmul(chunked_query, pad_key_states.transpose(2, 3))
|
||||
attn_weights_slice = attn_weights_slice / math.sqrt(head_dim) / stride / norm
|
||||
|
||||
if causal:
|
||||
causal_mask = torch.zeros((batch_size, num_q_head, chunk_size_actual, chunk_size_actual * k_chunk_num), device=key_states.device)
|
||||
causal_mask[:, :, :, -(k_num_to_pad // stride):] = float("-inf")
|
||||
# ... more causal mask logic ...
|
||||
attn_weights_slice = attn_weights_slice + causal_mask
|
||||
|
||||
attn_weights_slice = F.softmax(attn_weights_slice, dim=-1, dtype=torch.float32)
|
||||
attn_sum = attn_weights_slice.view(batch_size, num_q_head, chunk_size_actual // reshaped_block_size, reshaped_block_size, -1).sum(dim=-1).sum(dim=-2)
|
||||
|
||||
# Find blocks based on threshold
|
||||
simple_mask = find_blocks_chunked(
|
||||
attn_sum,
|
||||
k_block_num - q_block_num + chunk_idx * (reshaped_chunk_size // reshaped_block_size),
|
||||
threshold,
|
||||
None,
|
||||
decoding=False,
|
||||
mode="prefill",
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
attn_sum_list.append(attn_sum)
|
||||
simple_mask_list.append(simple_mask)
|
||||
|
||||
attn_sums = torch.cat(attn_sum_list, dim=-2)
|
||||
simple_masks = torch.cat(simple_mask_list, dim=-2)
|
||||
|
||||
# Apply causal mask to block masks
|
||||
if causal:
|
||||
simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
|
||||
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=key_states.device), diagonal=0),
|
||||
simple_masks[:, :, -q_block_num:, -q_block_num:],
|
||||
False,
|
||||
)
|
||||
|
||||
if keep_sink:
|
||||
simple_masks[:, :, 0, :] = True
|
||||
|
||||
if keep_recent:
|
||||
eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=bool)
|
||||
eye_matrix_expanded = eye_matrix.unsqueeze(0).unsqueeze(0).expand(1, num_q_head, q_block_num, q_block_num)
|
||||
simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
|
||||
eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:]
|
||||
)
|
||||
|
||||
return attn_sums, simple_masks
|
||||
|
||||
def _block_sparse_attention_fallback(
|
||||
self,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
block_size: int,
|
||||
q_len: int,
|
||||
k_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Fallback implementation using FlashAttention.
|
||||
|
||||
Since block_sparse_attn_func may not be available in all environments,
|
||||
this uses standard FlashAttention with full attention.
|
||||
"""
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
|
||||
batch_size, num_heads, _, head_dim = query_states.shape
|
||||
|
||||
# Convert to [seq, heads, dim] format
|
||||
q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim]
|
||||
k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||
v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||
|
||||
cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
|
||||
cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=q_len,
|
||||
max_seqlen_k=k_len,
|
||||
softmax_scale=1.0 / math.sqrt(head_dim),
|
||||
causal=True,
|
||||
)
|
||||
|
||||
# Convert back to [batch, seq, heads, dim]
|
||||
attn_output = attn_output.unsqueeze(0).transpose(1, 2)
|
||||
|
||||
return attn_output
|
||||
|
||||
except Exception as e:
|
||||
# Final fallback: PyTorch SDPA
|
||||
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query_states, key_states, value_states,
|
||||
attn_mask=None,
|
||||
is_causal=True,
|
||||
scale=1.0 / math.sqrt(query_states.shape[-1])
|
||||
)
|
||||
return attn_output
|
||||
return flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=seq_len,
|
||||
max_seqlen_k=seq_len,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=True,
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset policy state (no state to reset for XAttention)."""
|
||||
@@ -461,4 +305,6 @@ class XAttentionPolicy(SparsePolicy):
|
||||
return (f"XAttentionPolicy("
|
||||
f"stride={self.stride}, "
|
||||
f"threshold={self.threshold}, "
|
||||
f"use_triton={self.use_triton})")
|
||||
f"block_size={self.block_size}, "
|
||||
f"use_triton={self.use_triton}, "
|
||||
f"use_bsa={self.use_bsa})")
|
||||
|
||||
@@ -98,10 +98,10 @@ class Attention(nn.Module):
|
||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
||||
elif context.sparse_prefill_policy is not None:
|
||||
# Sparse prefill (GPU-only) - delegate to policy
|
||||
o = context.sparse_prefill_policy.sparse_prefill_attention(
|
||||
q, k, v, self.layer_id
|
||||
elif context.attention_policy is not None:
|
||||
# Attention via policy (GPU-only) - delegate to policy
|
||||
o = context.attention_policy.compute_prefill(
|
||||
q, k, v, self.layer_id, softmax_scale=self.scale
|
||||
)
|
||||
else:
|
||||
o = flash_attn_varlen_func(q, k, v,
|
||||
|
||||
@@ -14,9 +14,9 @@ class Context:
|
||||
context_lens: torch.Tensor | None = None
|
||||
block_tables: torch.Tensor | None = None
|
||||
|
||||
# Sparse prefill attention support (GPU-only path)
|
||||
# When set, uses policy.sparse_prefill_attention() instead of FlashAttention
|
||||
sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True
|
||||
# Attention policy support (GPU-only path)
|
||||
# When set, uses policy.compute_prefill() instead of FlashAttention
|
||||
attention_policy: Any = None # AttentionPolicy instance
|
||||
|
||||
|
||||
_CONTEXT = Context()
|
||||
@@ -35,7 +35,7 @@ def set_context(
|
||||
slot_mapping=None,
|
||||
context_lens=None,
|
||||
block_tables=None,
|
||||
sparse_prefill_policy=None,
|
||||
attention_policy=None,
|
||||
):
|
||||
global _CONTEXT
|
||||
_CONTEXT = Context(
|
||||
@@ -47,7 +47,7 @@ def set_context(
|
||||
slot_mapping=slot_mapping,
|
||||
context_lens=context_lens,
|
||||
block_tables=block_tables,
|
||||
sparse_prefill_policy=sparse_prefill_policy,
|
||||
attention_policy=attention_policy,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -32,11 +32,14 @@ def run_needle_test(
|
||||
enable_cpu_offload: bool = False,
|
||||
enable_quest: bool = False,
|
||||
enable_minference: bool = False,
|
||||
enable_xattn: bool = False,
|
||||
sparse_topk: int = 8,
|
||||
sparse_threshold: int = 4,
|
||||
minference_budget: float = 0.3,
|
||||
minference_vertical: int = 1000,
|
||||
minference_slash: int = 6096,
|
||||
xattn_threshold: float = 0.9,
|
||||
xattn_use_bsa: bool = True,
|
||||
gpu_utilization: float = 0.9,
|
||||
enforce_eager: bool = True,
|
||||
verbose: bool = True,
|
||||
@@ -56,11 +59,14 @@ def run_needle_test(
|
||||
enable_cpu_offload: Enable CPU offload mode
|
||||
enable_quest: Enable Quest sparse attention (decode-only Top-K)
|
||||
enable_minference: Enable MInference sparse prefill (GPU-only)
|
||||
enable_xattn: Enable XAttention sparse prefill with BSA
|
||||
sparse_topk: Top-K blocks for Quest
|
||||
sparse_threshold: Apply sparse only when blocks > threshold
|
||||
minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode)
|
||||
minference_vertical: Fixed vertical_size (only used when budget=None)
|
||||
minference_slash: Fixed slash_size (only used when budget=None)
|
||||
xattn_threshold: XAttention block selection threshold (0-1)
|
||||
xattn_use_bsa: Use Block Sparse Attention library
|
||||
gpu_utilization: GPU memory utilization fraction
|
||||
verbose: Print detailed output
|
||||
|
||||
@@ -68,7 +74,9 @@ def run_needle_test(
|
||||
True if test passed, False otherwise
|
||||
"""
|
||||
# Determine sparse policy
|
||||
if enable_minference:
|
||||
if enable_xattn:
|
||||
sparse_policy = SparsePolicyType.XATTN
|
||||
elif enable_minference:
|
||||
sparse_policy = SparsePolicyType.MINFERENCE
|
||||
elif enable_quest:
|
||||
sparse_policy = SparsePolicyType.QUEST
|
||||
@@ -94,6 +102,8 @@ def run_needle_test(
|
||||
print(f" MInference: adaptive (budget={minference_budget})")
|
||||
else:
|
||||
print(f" MInference: fixed (vertical={minference_vertical}, slash={minference_slash})")
|
||||
if enable_xattn:
|
||||
print(f" XAttention: threshold={xattn_threshold}, use_bsa={xattn_use_bsa}")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# 1. Initialize LLM
|
||||
@@ -111,7 +121,7 @@ def run_needle_test(
|
||||
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
|
||||
|
||||
# Set sparse policy (can be used with or without offload)
|
||||
if enable_minference or enable_quest:
|
||||
if enable_minference or enable_quest or enable_xattn:
|
||||
llm_kwargs["sparse_policy"] = sparse_policy
|
||||
|
||||
# MInference params (works with both GPU-only and offload mode)
|
||||
@@ -120,6 +130,11 @@ def run_needle_test(
|
||||
llm_kwargs["minference_vertical_size"] = minference_vertical
|
||||
llm_kwargs["minference_slash_size"] = minference_slash
|
||||
|
||||
# XAttention params
|
||||
if enable_xattn:
|
||||
llm_kwargs["xattn_threshold"] = xattn_threshold
|
||||
llm_kwargs["xattn_use_bsa"] = xattn_use_bsa
|
||||
|
||||
llm = LLM(model_path, **llm_kwargs)
|
||||
|
||||
# 2. Generate needle prompt
|
||||
@@ -224,6 +239,11 @@ if __name__ == "__main__":
|
||||
action="store_true",
|
||||
help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-xattn",
|
||||
action="store_true",
|
||||
help="Enable XAttention sparse prefill with Block Sparse Attention"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sparse-topk",
|
||||
type=int,
|
||||
@@ -254,6 +274,17 @@ if __name__ == "__main__":
|
||||
default=6096,
|
||||
help="Fixed slash_size (only used when budget=0)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--xattn-threshold",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="XAttention block selection threshold (0-1, higher=more blocks)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--xattn-no-bsa",
|
||||
action="store_true",
|
||||
help="Disable Block Sparse Attention (use FlashAttention fallback)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-utilization",
|
||||
type=float,
|
||||
@@ -291,11 +322,14 @@ if __name__ == "__main__":
|
||||
enable_cpu_offload=args.enable_offload,
|
||||
enable_quest=args.enable_quest,
|
||||
enable_minference=args.enable_minference,
|
||||
enable_xattn=args.enable_xattn,
|
||||
sparse_topk=args.sparse_topk,
|
||||
sparse_threshold=args.sparse_threshold,
|
||||
minference_budget=minference_budget,
|
||||
minference_vertical=args.minference_vertical,
|
||||
minference_slash=args.minference_slash,
|
||||
xattn_threshold=args.xattn_threshold,
|
||||
xattn_use_bsa=not args.xattn_no_bsa,
|
||||
gpu_utilization=args.gpu_utilization,
|
||||
enforce_eager=enforce_eager,
|
||||
verbose=True,
|
||||
|
||||
Reference in New Issue
Block a user