[WIP] Needs refactoring.

Zijie Tian
2026-01-22 22:20:34 +08:00
parent 69b779e252
commit 5fb0f67295
11 changed files with 514 additions and 548 deletions

View File

@@ -62,6 +62,7 @@ class Config:
xattn_keep_sink: bool = False # Always keep first block (sink tokens)
xattn_keep_recent: bool = False # Always keep recent diagonal blocks
xattn_norm: float = 1.0 # Normalization factor for attention scores
xattn_use_bsa: bool = True # Use Block Sparse Attention library (requires installation)
def __post_init__(self):
assert os.path.isdir(self.model)
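Illustrative usage sketch, not part of the commit: how the new flag would be set when building a Config directly. Field names follow the diff above; the import path for Config and the model path are assumptions.

    from nanovllm.config import Config, SparsePolicyType  # import path assumed

    cfg = Config(
        model="/path/to/model",                 # must be an existing directory (checked in __post_init__)
        sparse_policy=SparsePolicyType.XATTN,   # select the XAttention policy
        xattn_use_bsa=False,                    # skip Block Sparse Attention, fall back to FlashAttention
    )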

View File

@@ -57,8 +57,8 @@ class ModelRunner:
load_model(self.model, config.model)
self.sampler = GreedySampler()
# Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache)
self.sparse_prefill_policy = None
# Initialize attention_policy before warmup (will be configured in allocate_kv_cache)
self.attention_policy = None
#> Disable warmup for debugging
self.warmup_model()
@@ -178,38 +178,35 @@ class ModelRunner:
# Create KV cache manager using factory
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
# Create sparse prefill policy
# This is used for both GPU-only and CPU offload modes when policy supports prefill
self.sparse_prefill_policy = None
if config.sparse_policy != SparsePolicyType.FULL:
from nanovllm.kvcache.sparse import create_sparse_policy
# Create attention policy (always, including FULL)
# In layerwise offload mode, all attention goes through the policy
from nanovllm.kvcache.sparse import create_attention_policy
# Get policy-specific parameters based on type
if config.sparse_policy == SparsePolicyType.XATTN:
policy_kwargs = {
"stride": config.xattn_stride,
"threshold": config.xattn_threshold,
"chunk_size": config.xattn_chunk_size,
"use_triton": config.xattn_use_triton,
"keep_sink": config.xattn_keep_sink,
"keep_recent": config.xattn_keep_recent,
"norm": config.xattn_norm,
}
else: # MINFERENCE or others
policy_kwargs = {
"vertical_size": config.minference_vertical_size,
"slash_size": config.minference_slash_size,
"adaptive_budget": config.minference_adaptive_budget,
"num_sink_tokens": config.minference_num_sink_tokens,
"num_recent_diags": config.minference_num_recent_diags,
}
policy = create_sparse_policy(config.sparse_policy, **policy_kwargs)
# Only use if policy supports sparse prefill
if policy.supports_prefill:
self.sparse_prefill_policy = policy
logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}")
"use_bsa": config.xattn_use_bsa,
}
elif config.sparse_policy == SparsePolicyType.MINFERENCE:
policy_kwargs = {
"vertical_size": config.minference_vertical_size,
"slash_size": config.minference_slash_size,
"adaptive_budget": config.minference_adaptive_budget,
"num_sink_tokens": config.minference_num_sink_tokens,
"num_recent_diags": config.minference_num_recent_diags,
}
else: # FULL or QUEST
policy_kwargs = {}
self.attention_policy = create_attention_policy(config.sparse_policy, **policy_kwargs)
logger.info(f"Attention policy: {self.attention_policy}")
# Allocate cache through manager
self.kvcache_manager.allocate_cache(
@@ -395,7 +392,7 @@ class ModelRunner:
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
slot_mapping, None, block_tables,
sparse_prefill_policy=self.sparse_prefill_policy)
attention_policy=self.attention_policy)
return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]):
@@ -592,21 +589,11 @@ class ModelRunner:
# RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k)
# Sparse or Full attention (uses k, v directly - before store!)
if self.sparse_prefill_policy is not None:
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, layer_id
)
else:
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
# Compute attention using policy (uses k, v directly - before store!)
attn_output = self.attention_policy.compute_prefill(
q, k, v, layer_id,
softmax_scale=layer.self_attn.attn.scale,
)
# O projection
attn_output = attn_output.view(total_tokens, -1)
@@ -872,23 +859,11 @@ class ModelRunner:
# RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k)
# Sparse or Full attention
if self.sparse_prefill_policy is not None:
# MInference or other sparse prefill policy
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, layer_id
)
else:
# Full attention using FlashAttention
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
# Compute attention using policy
attn_output = self.attention_policy.compute_prefill(
q, k, v, layer_id,
softmax_scale=layer.self_attn.attn.scale,
)
# O projection
attn_output = attn_output.view(total_tokens, -1)

View File

@@ -1,49 +1,56 @@
"""
Sparse Attention Policy module.
Provides pluggable policies for selecting which KV blocks to load
during chunked attention with CPU offload.
Attention Policy module for layerwise offload mode.
Provides pluggable policies for attention computation:
- FullAttentionPolicy: Standard FlashAttention (no sparsity)
- XAttentionPolicy: Sparse prefill using XAttention algorithm
- MInferencePolicy: MInference sparse attention
- QuestPolicy: Quest block selection (for chunked offload)
Usage:
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
from nanovllm.kvcache.sparse import create_attention_policy, SparsePolicyType
# Create policy using factory function
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)
policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
# Use policy for attention
attn_output = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
# Or create custom policy
class MyPolicy(SparsePolicy):
class MyPolicy(AttentionPolicy):
supports_prefill = True
supports_decode = True
def select_blocks(self, available_blocks, ctx):
return available_blocks[:5] # Just first 5 blocks
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
# Custom attention computation
...
"""
from nanovllm.config import SparsePolicyType
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.policy import AttentionPolicy, SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.minference import MInferencePolicy
from nanovllm.kvcache.sparse.xattn import XAttentionPolicy
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
def create_attention_policy(policy_type: SparsePolicyType, **kwargs) -> AttentionPolicy:
"""
Create a sparse policy instance from an enum type.
The returned policy is not yet initialized. Call policy.initialize()
or let the framework call it during KV cache allocation.
Create an attention policy instance from an enum type.
All attention (including full attention) goes through a policy in layerwise
offload mode. The policy is responsible for computing prefill/decode attention.
Args:
policy_type: SparsePolicyType enum value
policy_type: SparsePolicyType enum value (FULL, XATTN, MINFERENCE, QUEST)
**kwargs: Policy-specific configuration options
Returns:
SparsePolicy instance (not initialized)
AttentionPolicy instance
Example:
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
policy.initialize(num_layers=28, num_kv_heads=8, ...)
policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
attn_out = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
"""
if policy_type == SparsePolicyType.FULL:
return FullAttentionPolicy()
@@ -75,21 +82,32 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
keep_sink=kwargs.get("keep_sink", False),
keep_recent=kwargs.get("keep_recent", False),
norm=kwargs.get("norm", 1.0),
use_bsa=kwargs.get("use_bsa", True),
)
else:
raise ValueError(f"Unknown policy type: {policy_type}")
# Backward compatibility alias
create_sparse_policy = create_attention_policy
__all__ = [
# New interface
"AttentionPolicy",
"create_attention_policy",
# Backward compatibility
"SparsePolicy",
"create_sparse_policy",
# Common types
"PolicyContext",
"SparsePolicyType",
# Policy implementations
"FullAttentionPolicy",
"QuestPolicy",
"QuestConfig",
"BlockMetadataManager",
"MInferencePolicy",
"XAttentionPolicy",
"create_sparse_policy",
]
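A minimal sketch, not part of the commit, exercising the factory and prefill entry point documented above. Tensor shapes follow the documented Args; FULL is used so no sparse estimation is involved, and flash-attn plus a CUDA device are assumed to be available.

    import torch
    from nanovllm.config import SparsePolicyType
    from nanovllm.kvcache.sparse import create_attention_policy

    policy = create_attention_policy(SparsePolicyType.FULL)
    seq_len, num_heads, num_kv_heads, head_dim = 1024, 32, 8, 128
    q = torch.randn(seq_len, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(seq_len, num_kv_heads, head_dim, device="cuda", dtype=torch.float16)
    # Full causal attention over the whole prompt; output keeps the query layout.
    out = policy.compute_prefill(q, k, v, layer_id=0, softmax_scale=head_dim ** -0.5)
    assert out.shape == (seq_len, num_heads, head_dim)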

View File

@@ -1,20 +1,21 @@
"""
Full attention policy - loads all blocks (no sparsity).
Full attention policy - standard FlashAttention without sparsity.
This serves as a baseline and default policy when sparse
attention is not needed.
"""
from typing import List
from .policy import SparsePolicy, PolicyContext
from typing import Optional
import torch
from .policy import AttentionPolicy
class FullAttentionPolicy(SparsePolicy):
class FullAttentionPolicy(AttentionPolicy):
"""
Full attention policy that loads all available blocks.
This is the default behavior with no sparsity - all previous
KV cache blocks are loaded for each query chunk.
Full attention policy using FlashAttention (no sparsity).
This is the default behavior with standard causal attention.
All tokens attend to all previous tokens.
Use this as:
- A baseline for comparing sparse policies
@@ -25,15 +26,55 @@ class FullAttentionPolicy(SparsePolicy):
# Full attention supports both prefill and decode
supports_prefill = True
supports_decode = True
requires_block_selection = False # Load all blocks, no selective loading
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""Return all blocks - no sparsity."""
return available_blocks
def estimate(
self,
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Optional[torch.Tensor]:
"""
Full attention - no sparse mask needed.
Returns None to indicate full attention should be used.
"""
return None
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute full causal attention using FlashAttention.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
from flash_attn.flash_attn_interface import flash_attn_varlen_func
seq_len = q.shape[0]
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=softmax_scale,
causal=True,
)
def __repr__(self) -> str:
return "FullAttentionPolicy()"
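For readers without flash-attn installed, a rough PyTorch-only equivalent of the compute_prefill above; this is an illustrative sketch, not part of the commit, and it assumes num_heads == num_kv_heads and a PyTorch version whose scaled_dot_product_attention accepts the scale argument.

    import torch
    import torch.nn.functional as F

    def full_prefill_reference(q, k, v, softmax_scale):
        # q/k/v: [seq_len, heads, head_dim] -> SDPA wants [batch, heads, seq, dim]
        q_b, k_b, v_b = (t.transpose(0, 1).unsqueeze(0) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q_b, k_b, v_b, is_causal=True, scale=softmax_scale)
        # back to [seq_len, heads, head_dim]
        return out.squeeze(0).transpose(0, 1)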

View File

@@ -10,10 +10,10 @@ from typing import List, Tuple, Optional
import torch
import torch.nn.functional as F
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.policy import AttentionPolicy, PolicyContext
class MInferencePolicy(SparsePolicy):
class MInferencePolicy(AttentionPolicy):
"""
MInference sparse prefill policy using vertical + slash pattern.
@@ -347,6 +347,33 @@ class MInferencePolicy(SparsePolicy):
return o
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute MInference sparse prefill attention.
This is the new unified interface for attention policies.
Delegates to sparse_prefill_attention (ignores softmax_scale as MInference
computes it internally from head_dim).
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
softmax_scale: Softmax scaling factor (unused, computed internally)
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
return self.sparse_prefill_attention(q, k, v, layer_id)
def __repr__(self) -> str:
return (f"MInferencePolicy("
f"adaptive_budget={self.adaptive_budget}, "

View File

@@ -1,13 +1,18 @@
"""
Base class for sparse attention policies.
Sparse attention policies determine which KV cache blocks to load
from CPU for each query chunk during chunked attention computation.
Base class for attention policies in layerwise offload mode.
AttentionPolicy defines the interface for all attention computation,
including full attention and sparse attention methods like XAttention.
Key methods:
- estimate(): Compute sparse attention mask (optional, returns None for full attention)
- compute_prefill(): Compute prefill attention
- compute_decode(): Compute decode attention (default implementation provided)
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Any
from typing import List, Optional, Tuple
import torch
# Import SparsePolicyType from config to avoid circular imports
@@ -17,10 +22,10 @@ from nanovllm.config import SparsePolicyType
@dataclass
class PolicyContext:
"""
Context passed to sparse policy for block selection.
This dataclass contains all information needed by a sparse policy
to decide which blocks to load for the current query chunk.
Context passed to attention policy for block selection.
This dataclass contains all information needed by an attention policy
for sparse estimation and attention computation.
"""
query_chunk_idx: int
@@ -49,40 +54,41 @@ class PolicyContext:
"""Total KV sequence length so far (for reference)."""
class SparsePolicy(ABC):
class AttentionPolicy(ABC):
"""
Abstract base class for sparse attention policies.
Subclass this and implement select_blocks() to create custom
sparse attention patterns. The policy receives context about
the current query chunk and returns which KV blocks to load.
Base class for attention policies in layerwise offload mode.
All attention computation goes through a policy, including both
full attention and sparse attention methods.
The policy interface is designed for layerwise offload where:
- The entire KV cache for a layer is on GPU during computation
- No need for block loading from CPU during attention
- estimate() returns a sparse mask (or None for full attention)
- compute_prefill()/compute_decode() perform the actual attention
Attributes:
supports_prefill: Whether this policy can be used for prefill phase.
supports_decode: Whether this policy can be used for decode phase.
Example:
class MySparsePolicy(SparsePolicy):
supports_prefill = False # decode-only policy
supports_decode = True
def select_blocks(self, available_blocks, ctx):
# Load first block and last 2 blocks
if len(available_blocks) <= 3:
return available_blocks
return [available_blocks[0]] + available_blocks[-2:]
class MyPolicy(AttentionPolicy):
supports_prefill = True
supports_decode = True
def estimate(self, q, k, layer_id):
# Return sparse mask or None
return None
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
# Compute attention
return flash_attn_varlen_func(q, k, v, ...)
"""
# Compatibility flags - override in subclasses
supports_prefill: bool = True
supports_decode: bool = True
# Whether this policy requires selective block loading during decode
# If True: OffloadEngine will call select_blocks() before loading KV from CPU
# If False: OffloadEngine will load all blocks (select_blocks ignored for load)
# Example: MInference=False (only affects attention), Quest=True (affects load)
requires_block_selection: bool = False
def initialize(
self,
num_layers: int,
@@ -96,7 +102,7 @@ class SparsePolicy(ABC):
Initialize policy resources.
Called by the framework after KV cache is allocated. Override this
to create metadata structures (e.g., BlockMetadataManager for Quest).
to create metadata structures or pre-allocate buffers.
Default implementation does nothing.
Args:
@@ -109,76 +115,98 @@ class SparsePolicy(ABC):
"""
pass
@abstractmethod
def select_blocks(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""
Select which KV blocks to load for the current query chunk.
This is the core method that defines the sparse attention pattern.
The returned blocks will be loaded from CPU to GPU for attention
computation against the current query chunk.
Args:
available_blocks: List of CPU block IDs that contain KV cache
from previous chunks. These are ordered by
their position in the sequence.
ctx: PolicyContext with information about the current query
chunk, layer, phase (prefill/decode), etc.
Returns:
List of block IDs to load (must be a subset of available_blocks).
The order may affect performance (sequential access is faster).
Returning [] means no previous blocks will be loaded.
"""
pass
def estimate(
self,
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Optional[torch.Tensor]:
"""
Estimate sparse attention mask.
For sparse policies (e.g., XAttention), computes block-level importance
and returns a boolean mask indicating which blocks to attend.
For full attention policy, returns None.
This corresponds to xattn_estimate() in COMPASS.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
Returns:
sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
or None for full attention
"""
return None
def on_prefill_offload(
self,
cpu_block_id: int,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
"""
Hook called when a block is offloaded during prefill phase.
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
Override this to collect metadata about blocks (e.g., min/max keys
for Quest-style selection). Default implementation does nothing.
Args:
cpu_block_id: The CPU block ID that will be written
layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
"""
pass
@abstractmethod
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute prefill attention.
The entire KV cache for this layer is on GPU. Compute attention
between Q and K/V, optionally using sparse mask from estimate().
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
pass
def on_decode_offload(
self,
cpu_block_id: int,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
"""
Hook called when a block is offloaded during decode phase.
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
Override this to update metadata about blocks. Default implementation
does nothing.
Args:
cpu_block_id: The CPU block ID that will be written
layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
"""
pass
def compute_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute decode attention.
KV is provided from ring buffer, containing prefill tokens + decoded tokens.
Default implementation uses FlashAttention.
Args:
q: Query tensor [1, num_heads, head_dim]
k: Key tensor [context_len+1, num_kv_heads, head_dim]
v: Value tensor [context_len+1, num_kv_heads, head_dim]
layer_id: Transformer layer index
softmax_scale: Softmax scaling factor
Returns:
Attention output [1, num_heads, head_dim]
"""
from flash_attn.flash_attn_interface import flash_attn_varlen_func
context_len = k.shape[0]
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=context_len,
softmax_scale=softmax_scale,
causal=False,
)
def reset(self) -> None:
"""
@@ -189,32 +217,9 @@ class SparsePolicy(ABC):
"""
pass
def sparse_prefill_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
) -> torch.Tensor:
"""
Compute sparse attention for prefill phase.
This method is called when supports_prefill=True and the policy
is used for GPU-only sparse prefill (no CPU offload).
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Current transformer layer index
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement sparse_prefill_attention. "
"Set supports_prefill=False or implement this method."
)
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
# Backward compatibility alias
SparsePolicy = AttentionPolicy
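A minimal subclass sketch against the new AttentionPolicy interface (illustrative only, not part of the commit). It returns no sparse mask and reuses FlashAttention for prefill, so it behaves like FullAttentionPolicy; the class name is hypothetical.

    import torch
    from typing import Optional
    from nanovllm.kvcache.sparse.policy import AttentionPolicy


    class PassthroughPolicy(AttentionPolicy):
        supports_prefill = True
        supports_decode = True

        def estimate(self, q, k, layer_id) -> Optional[torch.Tensor]:
            return None  # no block mask -> full attention

        def compute_prefill(self, q, k, v, layer_id, softmax_scale):
            from flash_attn.flash_attn_interface import flash_attn_varlen_func
            seq_len = q.shape[0]
            cu = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
            return flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu, cu_seqlens_k=cu,
                max_seqlen_q=seq_len, max_seqlen_k=seq_len,
                softmax_scale=softmax_scale, causal=True,
            )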

View File

@@ -11,7 +11,7 @@ import logging
import torch
from dataclasses import dataclass
from typing import List, Tuple, Optional
from .policy import SparsePolicy, PolicyContext
from .policy import AttentionPolicy, PolicyContext
logger = logging.getLogger(__name__)
@@ -137,7 +137,7 @@ class QuestConfig:
"""Always include this many recent blocks (last N blocks), in addition to Top-K."""
class QuestPolicy(SparsePolicy):
class QuestPolicy(AttentionPolicy):
"""
Quest-style Top-K block selection using min/max key bounds.
@@ -317,6 +317,25 @@ class QuestPolicy(SparsePolicy):
if self.metadata is not None:
self.metadata.reset()
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Quest does not support prefill - raises error.
Quest is a decode-only policy for selective block loading.
For prefill, use FullAttentionPolicy or XAttentionPolicy.
"""
raise NotImplementedError(
"QuestPolicy does not support prefill. "
"Use FullAttentionPolicy or XAttentionPolicy for prefill."
)
def __repr__(self) -> str:
return (
f"QuestPolicy(topk={self.config.topk_blocks}, "

View File

@@ -4,48 +4,56 @@ XAttention sparse attention policy for nano-vllm.
Implements the XAttention algorithm from COMPASS, using chunked estimation
and block sparse attention for efficient long-context inference.
Architecture:
XAttention = Estimate (Triton) + Compute (BSA)
- Estimate: xattn_estimate() computes block-level importance scores
- Compute: block_sparse_attn_func() executes sparse attention
Reference: COMPASS/compass/src/Xattention.py
"""
import math
from typing import List, Optional
from typing import Optional
import torch
import torch.nn.functional as F
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.kernels import (
flat_group_gemm_fuse_reshape,
softmax_fuse_block_sum,
)
from nanovllm.kvcache.sparse.utils import find_blocks_chunked
from nanovllm.kvcache.sparse.policy import AttentionPolicy
# BSA block size is fixed at 128 (hardcoded in block_sparse_attn)
BSA_BLOCK_SIZE = 128
class XAttentionPolicy(SparsePolicy):
class XAttentionPolicy(AttentionPolicy):
"""
XAttention sparse prefill policy using chunked estimation + block sparse attention.
This policy estimates sparse attention patterns by:
1. Chunked QK computation using Triton kernels
1. Chunked QK computation using Triton kernels (via nanovllm.ops.xattn)
2. Block-wise softmax with importance scores
3. Block selection based on threshold
4. Block sparse attention computation
4. Block sparse attention computation using MIT-HAN-LAB BSA library
The key method is estimate() which calls xattn_estimate() from nanovllm.ops
to compute the sparse attention mask.
Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
BSA library: https://github.com/mit-han-lab/Block-Sparse-Attention
"""
supports_prefill = True
supports_decode = False # XAttention is prefill-only
requires_block_selection = False # Only affects attention computation
supports_decode = True # Uses default FlashAttention for decode
def __init__(
self,
stride: int = 8,
threshold: float = 0.9,
chunk_size: Optional[int] = None,
block_size: int = 128,
chunk_size: int = 16384,
use_triton: bool = True,
keep_sink: bool = False,
keep_recent: bool = False,
norm: float = 1.0,
use_bsa: bool = True,
):
"""
Initialize XAttention policy.
@@ -53,19 +61,28 @@ class XAttentionPolicy(SparsePolicy):
Args:
stride: Stride for reorganizing Q/K (default: 8)
threshold: Block selection threshold, 0-1 (default: 0.9)
chunk_size: Chunk size for estimation (auto if None)
block_size: Block size for sparse attention (default: 128, must match BSA)
chunk_size: Chunk size for estimation (default: 16384)
use_triton: Use Triton kernels (requires SM 80+)
keep_sink: Always keep first block (sink tokens)
keep_recent: Always keep recent diagonal blocks
norm: Normalization factor for attention scores
use_bsa: Use Block Sparse Attention library (default: True)
"""
self.stride = stride
self.threshold = threshold
self.block_size = block_size
self.chunk_size = chunk_size
self.use_triton = use_triton
self.keep_sink = keep_sink
self.keep_recent = keep_recent
self.norm = norm
self.use_bsa = use_bsa
# BSA requires block_size = 128
if self.use_bsa and self.block_size != BSA_BLOCK_SIZE:
print(f"XAttention: BSA requires block_size=128, adjusting from {self.block_size}")
self.block_size = BSA_BLOCK_SIZE
# Check Triton availability
if self.use_triton:
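Quick illustration, not part of the commit, of the coercion added above: constructing the policy with a non-128 block size while BSA is enabled snaps it back to BSA_BLOCK_SIZE.

    from nanovllm.kvcache.sparse import XAttentionPolicy

    policy = XAttentionPolicy(block_size=64, use_bsa=True)   # __init__ prints an adjustment notice
    assert policy.block_size == 128                          # forced to BSA_BLOCK_SIZE when use_bsa is True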
@@ -79,379 +96,206 @@ class XAttentionPolicy(SparsePolicy):
self.use_triton = False
print("XAttention: Triton not available. Falling back to PyTorch.")
def select_blocks(
# Check BSA availability
if self.use_bsa:
try:
from block_sparse_attn import block_sparse_attn_func
except ImportError:
self.use_bsa = False
print("XAttention: block_sparse_attn not available. Falling back to FlashAttention.")
def estimate(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
self,
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Optional[torch.Tensor]:
"""
Select blocks for decode phase.
XAttention is prefill-only, so this method is only used as a fallback.
Returns all available blocks by default.
Estimate sparse attention mask using XAttention algorithm.
Calls xattn_estimate() from nanovllm.ops.xattn to compute block-level
importance scores and generate a sparse boolean mask.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
Returns:
sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
or None if estimation fails (fallback to full attention)
"""
# XAttention is prefill-only, but we need to implement this abstract method
# Since requires_block_selection=False, this won't be called for loading
return available_blocks
def sparse_prefill_attention(
try:
from nanovllm.ops.xattn import xattn_estimate
seq_len, num_heads, head_dim = q.shape
num_kv_heads = k.shape[1]
# Convert to [batch, heads, seq, dim] format expected by xattn_estimate
q_bhsd = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim]
k_bhsd = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, seq_len, head_dim]
# Handle GQA: expand k to match q heads for estimation
if num_kv_heads != num_heads:
# GQA: expand k by repeating
repeat_factor = num_heads // num_kv_heads
k_bhsd = k_bhsd.repeat(1, repeat_factor, 1, 1)
# Call xattn_estimate
attn_sums, sparse_mask = xattn_estimate(
q_bhsd, k_bhsd,
block_size=self.block_size,
stride=self.stride,
norm=self.norm,
threshold=self.threshold,
chunk_size=self.chunk_size,
use_triton=self.use_triton,
causal=True,
keep_sink=self.keep_sink,
keep_recent=self.keep_recent,
)
return sparse_mask
except Exception as e:
# If estimation fails, return None to use full attention
print(f"XAttention estimate failed: {e}, falling back to full attention")
return None
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute XAttention sparse attention for prefill.
Compute XAttention sparse prefill attention.
Flow:
1. Call estimate() to get sparse mask
2. If mask is None or BSA unavailable, use full FlashAttention
3. Otherwise, use block_sparse_attn_func with mask
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Current transformer layer index
layer_id: Transformer layer index
softmax_scale: Softmax scaling factor
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
seq_len = q.shape[0]
num_heads = q.shape[1]
head_dim = q.shape[2]
# If BSA is disabled, use full attention directly (skip estimation)
if not self.use_bsa:
return self._full_attention(q, k, v, softmax_scale)
num_kv_heads = k.shape[1]
# Use FlashAttention directly for CPU offload mode
# FlashAttention supports GQA natively
# Step 1: Estimate sparse mask
sparse_mask = self.estimate(q, k, layer_id)
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
# Step 2: Compute attention
if sparse_mask is None:
# Estimation failed, fallback to full FlashAttention
return self._full_attention(q, k, v, softmax_scale)
attn_output = flash_attn_varlen_func(
q, k, v,
# Use block sparse attention with mask
return self._block_sparse_attention(q, k, v, sparse_mask, softmax_scale)
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=1.0 / math.sqrt(head_dim),
causal=True,
)
return attn_output
def _block_sparse_attention(
except Exception as e:
# Fallback: PyTorch SDPA (supports GQA natively)
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
is_causal=True,
scale=1.0 / math.sqrt(head_dim)
)
return attn_output
def _xattn_offload_prefill(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
causal: bool = True,
) -> torch.Tensor:
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
sparse_mask: torch.Tensor,
softmax_scale: float,
) -> torch.Tensor:
"""
Simplified XAttention prefill for CPU offload mode.
Compute block sparse attention using MIT-HAN-LAB BSA library.
Uses FlashAttention with full context since chunked estimation
with full key_states requires special handling.
"""
batch_size, num_heads, q_len, head_dim = query_states.shape
_, _, k_len, _ = key_states.shape
# Use FlashAttention with full context
# In offload mode, keys are already on CPU and loaded as needed
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# Convert to [seq, heads, dim] format
q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim]
k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=k_len,
softmax_scale=1.0 / math.sqrt(head_dim),
causal=causal,
)
# Convert back to [batch, seq, heads, dim]
attn_output = attn_output.unsqueeze(0).transpose(1, 2) # [1, q_len, num_heads, head_dim]
return attn_output
except Exception as e:
# Final fallback: PyTorch SDPA
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states,
attn_mask=None,
is_causal=causal,
scale=1.0 / math.sqrt(head_dim)
)
return attn_output
def _xattn_prefill(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
stride: int,
norm: float,
threshold: float,
block_size: int = 128,
use_triton: bool = True,
causal: bool = True,
chunk_size: Optional[int] = None,
keep_sink: bool = False,
keep_recent: bool = False,
) -> torch.Tensor:
"""
XAttention prefill implementation.
Args:
query_states: [batch, num_heads, q_len, head_dim]
key_states: [batch, num_heads, k_len, head_dim]
value_states: [batch, num_heads, k_len, head_dim]
... other params
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
sparse_mask: Block mask [batch, num_heads, q_blocks, k_blocks]
softmax_scale: Softmax scaling factor
Returns:
Attention output [batch, q_len, num_heads, head_dim]
Attention output [seq_len, num_heads, head_dim]
"""
batch_size, num_heads, k_len, head_dim = key_states.shape
from block_sparse_attn import block_sparse_attn_func
_, _, q_len, _ = query_states.shape
# Auto-compute chunk_size if not specified
if chunk_size is None:
seq_len, num_heads, head_dim = q.shape
num_kv_heads = k.shape[1]
chunk_size = int(
max(
min(
max(2048, 1 << (k_len - 1).bit_length()),
128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()),
),
2048,
)
)
# Phase 1: Estimate sparse pattern
attn_sums, approx_simple_mask = self._xattn_estimate(
query_states,
key_states,
block_size=block_size,
# Handle GQA: expand K/V to match Q heads
if num_kv_heads != num_heads:
repeat_factor = num_heads // num_kv_heads
k = k.repeat_interleave(repeat_factor, dim=1)
v = v.repeat_interleave(repeat_factor, dim=1)
stride=stride,
norm=norm,
threshold=threshold,
chunk_size=chunk_size,
use_triton=use_triton,
causal=causal,
keep_sink=keep_sink,
keep_recent=keep_recent,
)
# Phase 2: Block sparse attention
# For now, use FlashAttention as fallback since block_sparse_attn_func may not be available
attn_output = self._block_sparse_attention_fallback(
query_states, key_states, value_states,
approx_simple_mask, block_size, q_len, k_len
# Cumulative sequence lengths (batch=1)
cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
# Head mask type: 1 for all heads using block sparse
head_mask_type = torch.ones(num_heads, dtype=torch.int32, device=q.device)
# Trim sparse_mask to actual block counts
q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
block_mask = sparse_mask[:, :, :q_blocks, :k_blocks].contiguous()
# Call BSA
attn_output = block_sparse_attn_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
None, # streaming_info (left_mask)
block_mask,
seq_len, seq_len,
p_dropout=0.0,
deterministic=True,
softmax_scale=softmax_scale,
is_causal=True,
)
return attn_output
def _xattn_estimate(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
block_size: int,
stride: int,
def _full_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: float,
norm: float = 1,
softmax: bool = True,
threshold: float = 0.9,
chunk_size: int = 16384,
use_triton: bool = True,
causal: bool = True,
keep_sink: bool = False,
keep_recent: bool = False,
) -> torch.Tensor:
"""
Estimate sparse attention pattern using chunked computation.
Compute full causal attention using FlashAttention.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
softmax_scale: Softmax scaling factor
Returns:
attn_sums: [batch, heads, q_blocks, k_blocks] - importance scores
simple_masks: [batch, heads, q_blocks, k_blocks] - boolean masks
Attention output [seq_len, num_heads, head_dim]
"""
batch_size, num_kv_head, k_len, head_dim = key_states.shape
from flash_attn.flash_attn_interface import flash_attn_varlen_func
batch_size, num_q_head, q_len, head_dim = query_states.shape
k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
seq_len = q.shape[0]
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
k_chunk_num = (k_len + k_num_to_pad) // chunk_size
k_block_num = (k_len + k_num_to_pad) // block_size
q_chunk_num = (q_len + q_num_to_pad) // chunk_size
q_block_num = (q_len + q_num_to_pad) // block_size
# Pad inputs
if k_num_to_pad > 0:
pad_key_states = F.pad(key_states, (0, 0, 0, k_num_to_pad), value=0)
else:
pad_key_states = key_states
if q_num_to_pad > 0:
pad_query_states = F.pad(query_states, (0, 0, 0, q_num_to_pad), value=0)
else:
pad_query_states = query_states
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=softmax_scale,
causal=True,
)
reshaped_chunk_size = chunk_size // stride
reshaped_block_size = block_size // stride
k_reshaped_seq_len = (k_len + k_num_to_pad) // stride
attn_sum_list = []
simple_mask_list = []
for chunk_idx in range(q_chunk_num):
if use_triton:
# Triton GEMM + Softmax
attn_weights_slice = flat_group_gemm_fuse_reshape(
pad_query_states[:, :, (chunk_idx * reshaped_chunk_size) * stride : (chunk_idx * reshaped_chunk_size + reshaped_chunk_size) * stride, :],
pad_key_states,
stride,
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
is_causal=causal,
)
attn_sum = softmax_fuse_block_sum(
attn_weights_slice,
reshaped_block_size,
min(4096, reshaped_block_size),
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
k_reshaped_seq_len - (k_num_to_pad // stride),
1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
is_causal=causal,
)
else:
# PyTorch fallback
chunk_size_actual = reshaped_chunk_size
chunk_start = chunk_idx * chunk_size_actual
chunk_end = chunk_start + chunk_size_actual
chunked_query = pad_query_states[:, :, chunk_start * stride:chunk_end * stride:stride, :]
attn_weights_slice = torch.matmul(chunked_query, pad_key_states.transpose(2, 3))
attn_weights_slice = attn_weights_slice / math.sqrt(head_dim) / stride / norm
if causal:
causal_mask = torch.zeros((batch_size, num_q_head, chunk_size_actual, chunk_size_actual * k_chunk_num), device=key_states.device)
causal_mask[:, :, :, -(k_num_to_pad // stride):] = float("-inf")
# ... more causal mask logic ...
attn_weights_slice = attn_weights_slice + causal_mask
attn_weights_slice = F.softmax(attn_weights_slice, dim=-1, dtype=torch.float32)
attn_sum = attn_weights_slice.view(batch_size, num_q_head, chunk_size_actual // reshaped_block_size, reshaped_block_size, -1).sum(dim=-1).sum(dim=-2)
# Find blocks based on threshold
simple_mask = find_blocks_chunked(
attn_sum,
k_block_num - q_block_num + chunk_idx * (reshaped_chunk_size // reshaped_block_size),
threshold,
None,
decoding=False,
mode="prefill",
causal=causal,
)
attn_sum_list.append(attn_sum)
simple_mask_list.append(simple_mask)
attn_sums = torch.cat(attn_sum_list, dim=-2)
simple_masks = torch.cat(simple_mask_list, dim=-2)
# Apply causal mask to block masks
if causal:
simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=key_states.device), diagonal=0),
simple_masks[:, :, -q_block_num:, -q_block_num:],
False,
)
if keep_sink:
simple_masks[:, :, 0, :] = True
if keep_recent:
eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=bool)
eye_matrix_expanded = eye_matrix.unsqueeze(0).unsqueeze(0).expand(1, num_q_head, q_block_num, q_block_num)
simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:]
)
return attn_sums, simple_masks
def _block_sparse_attention_fallback(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
mask: torch.Tensor,
block_size: int,
q_len: int,
k_len: int,
) -> torch.Tensor:
"""
Fallback implementation using FlashAttention.
Since block_sparse_attn_func may not be available in all environments,
this uses standard FlashAttention with full attention.
"""
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
batch_size, num_heads, _, head_dim = query_states.shape
# Convert to [seq, heads, dim] format
q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim]
k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=k_len,
softmax_scale=1.0 / math.sqrt(head_dim),
causal=True,
)
# Convert back to [batch, seq, heads, dim]
attn_output = attn_output.unsqueeze(0).transpose(1, 2)
return attn_output
except Exception as e:
# Final fallback: PyTorch SDPA
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states,
attn_mask=None,
is_causal=True,
scale=1.0 / math.sqrt(query_states.shape[-1])
)
return attn_output
def reset(self) -> None:
"""Reset policy state (no state to reset for XAttention)."""
@@ -461,4 +305,6 @@ class XAttentionPolicy(SparsePolicy):
return (f"XAttentionPolicy("
f"stride={self.stride}, "
f"threshold={self.threshold}, "
f"use_triton={self.use_triton})")
f"block_size={self.block_size}, "
f"use_triton={self.use_triton}, "
f"use_bsa={self.use_bsa})")
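Illustrative arithmetic, not part of the commit, for the mask-trimming step in _block_sparse_attention above, taking a 10,000-token prompt as an example:

    BSA_BLOCK_SIZE = 128
    seq_len = 10_000
    q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE   # ceil(10000 / 128) = 79
    k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE   # 79
    # estimate() may return a mask covering padded blocks, so it is sliced to
    # [:, :, :79, :79] and made contiguous before block_sparse_attn_func is called.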

View File

@@ -98,10 +98,10 @@ class Attention(nn.Module):
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
elif context.sparse_prefill_policy is not None:
# Sparse prefill (GPU-only) - delegate to policy
o = context.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, self.layer_id
elif context.attention_policy is not None:
# Attention via policy (GPU-only) - delegate to policy
o = context.attention_policy.compute_prefill(
q, k, v, self.layer_id, softmax_scale=self.scale
)
else:
o = flash_attn_varlen_func(q, k, v,

View File

@@ -14,9 +14,9 @@ class Context:
context_lens: torch.Tensor | None = None
block_tables: torch.Tensor | None = None
# Sparse prefill attention support (GPU-only path)
# When set, uses policy.sparse_prefill_attention() instead of FlashAttention
sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True
# Attention policy support (GPU-only path)
# When set, uses policy.compute_prefill() instead of FlashAttention
attention_policy: Any = None # AttentionPolicy instance
_CONTEXT = Context()
@@ -35,7 +35,7 @@ def set_context(
slot_mapping=None,
context_lens=None,
block_tables=None,
sparse_prefill_policy=None,
attention_policy=None,
):
global _CONTEXT
_CONTEXT = Context(
@@ -47,7 +47,7 @@ def set_context(
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
sparse_prefill_policy=sparse_prefill_policy,
attention_policy=attention_policy,
)
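A small sketch, not part of the commit, of how the renamed field travels from prepare_prefill into the attention layer. The accessor name get_context is an assumption based on how the Context dataclass is read in attention.py; the positional arguments mirror the prepare_prefill call shown earlier.

    # in prepare_prefill (mirrors the diff above)
    set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                slot_mapping, None, block_tables,
                attention_policy=self.attention_policy)

    # in the attention layer
    context = get_context()            # assumed accessor for _CONTEXT
    if context.attention_policy is not None:
        o = context.attention_policy.compute_prefill(q, k, v, self.layer_id,
                                                      softmax_scale=self.scale)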

View File

@@ -32,11 +32,14 @@ def run_needle_test(
enable_cpu_offload: bool = False,
enable_quest: bool = False,
enable_minference: bool = False,
enable_xattn: bool = False,
sparse_topk: int = 8,
sparse_threshold: int = 4,
minference_budget: float = 0.3,
minference_vertical: int = 1000,
minference_slash: int = 6096,
xattn_threshold: float = 0.9,
xattn_use_bsa: bool = True,
gpu_utilization: float = 0.9,
enforce_eager: bool = True,
verbose: bool = True,
@@ -56,11 +59,14 @@ def run_needle_test(
enable_cpu_offload: Enable CPU offload mode
enable_quest: Enable Quest sparse attention (decode-only Top-K)
enable_minference: Enable MInference sparse prefill (GPU-only)
enable_xattn: Enable XAttention sparse prefill with BSA
sparse_topk: Top-K blocks for Quest
sparse_threshold: Apply sparse only when blocks > threshold
minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode)
minference_vertical: Fixed vertical_size (only used when budget=None)
minference_slash: Fixed slash_size (only used when budget=None)
xattn_threshold: XAttention block selection threshold (0-1)
xattn_use_bsa: Use Block Sparse Attention library
gpu_utilization: GPU memory utilization fraction
verbose: Print detailed output
@@ -68,7 +74,9 @@ def run_needle_test(
True if test passed, False otherwise
"""
# Determine sparse policy
if enable_minference:
if enable_xattn:
sparse_policy = SparsePolicyType.XATTN
elif enable_minference:
sparse_policy = SparsePolicyType.MINFERENCE
elif enable_quest:
sparse_policy = SparsePolicyType.QUEST
@@ -94,6 +102,8 @@ def run_needle_test(
print(f" MInference: adaptive (budget={minference_budget})")
else:
print(f" MInference: fixed (vertical={minference_vertical}, slash={minference_slash})")
if enable_xattn:
print(f" XAttention: threshold={xattn_threshold}, use_bsa={xattn_use_bsa}")
print(f"{'='*60}\n")
# 1. Initialize LLM
@@ -111,7 +121,7 @@ def run_needle_test(
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
# Set sparse policy (can be used with or without offload)
if enable_minference or enable_quest:
if enable_minference or enable_quest or enable_xattn:
llm_kwargs["sparse_policy"] = sparse_policy
# MInference params (works with both GPU-only and offload mode)
@@ -120,6 +130,11 @@ def run_needle_test(
llm_kwargs["minference_vertical_size"] = minference_vertical
llm_kwargs["minference_slash_size"] = minference_slash
# XAttention params
if enable_xattn:
llm_kwargs["xattn_threshold"] = xattn_threshold
llm_kwargs["xattn_use_bsa"] = xattn_use_bsa
llm = LLM(model_path, **llm_kwargs)
# 2. Generate needle prompt
@@ -224,6 +239,11 @@ if __name__ == "__main__":
action="store_true",
help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)"
)
parser.add_argument(
"--enable-xattn",
action="store_true",
help="Enable XAttention sparse prefill with Block Sparse Attention"
)
parser.add_argument(
"--sparse-topk",
type=int,
@@ -254,6 +274,17 @@ if __name__ == "__main__":
default=6096,
help="Fixed slash_size (only used when budget=0)"
)
parser.add_argument(
"--xattn-threshold",
type=float,
default=0.9,
help="XAttention block selection threshold (0-1, higher=more blocks)"
)
parser.add_argument(
"--xattn-no-bsa",
action="store_true",
help="Disable Block Sparse Attention (use FlashAttention fallback)"
)
parser.add_argument(
"--gpu-utilization",
type=float,
@@ -291,11 +322,14 @@ if __name__ == "__main__":
enable_cpu_offload=args.enable_offload,
enable_quest=args.enable_quest,
enable_minference=args.enable_minference,
enable_xattn=args.enable_xattn,
sparse_topk=args.sparse_topk,
sparse_threshold=args.sparse_threshold,
minference_budget=minference_budget,
minference_vertical=args.minference_vertical,
minference_slash=args.minference_slash,
xattn_threshold=args.xattn_threshold,
xattn_use_bsa=not args.xattn_no_bsa,
gpu_utilization=args.gpu_utilization,
enforce_eager=enforce_eager,
verbose=True,
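Illustrative only, not part of the commit: with --enable-xattn and default flag values, the extra kwargs assembled by the test before constructing LLM() reduce to roughly the following.

    from nanovllm.config import SparsePolicyType

    llm_kwargs = {
        "sparse_policy": SparsePolicyType.XATTN,
        "xattn_threshold": 0.9,      # --xattn-threshold
        "xattn_use_bsa": True,       # cleared by --xattn-no-bsa
    }
    # llm = LLM(model_path, **llm_kwargs)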