[WIP] need refactor.

This commit is contained in:
Zijie Tian
2026-01-22 22:20:34 +08:00
parent 69b779e252
commit 5fb0f67295
11 changed files with 514 additions and 548 deletions

View File

@@ -62,6 +62,7 @@ class Config:
xattn_keep_sink: bool = False # Always keep first block (sink tokens)
xattn_keep_recent: bool = False # Always keep recent diagonal blocks
xattn_norm: float = 1.0 # Normalization factor for attention scores
xattn_use_bsa: bool = True # Use Block Sparse Attention library (requires installation)
def __post_init__(self):
assert os.path.isdir(self.model)

View File

@@ -57,8 +57,8 @@ class ModelRunner:
load_model(self.model, config.model)
self.sampler = GreedySampler()
# Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache)
self.sparse_prefill_policy = None
# Initialize attention_policy before warmup (will be configured in allocate_kv_cache)
self.attention_policy = None
#> Disable warmup for debugging
self.warmup_model()
@@ -178,38 +178,35 @@ class ModelRunner:
# Create KV cache manager using factory
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
# Create sparse prefill policy
# This is used for both GPU-only and CPU offload modes when policy supports prefill
self.sparse_prefill_policy = None
if config.sparse_policy != SparsePolicyType.FULL:
from nanovllm.kvcache.sparse import create_sparse_policy
# Create attention policy (always, including FULL)
# In layerwise offload mode, all attention goes through the policy
from nanovllm.kvcache.sparse import create_attention_policy
# Get policy-specific parameters based on type
if config.sparse_policy == SparsePolicyType.XATTN:
policy_kwargs = {
"stride": config.xattn_stride,
"threshold": config.xattn_threshold,
"chunk_size": config.xattn_chunk_size,
"use_triton": config.xattn_use_triton,
"keep_sink": config.xattn_keep_sink,
"keep_recent": config.xattn_keep_recent,
"norm": config.xattn_norm,
}
else: # MINFERENCE or others
policy_kwargs = {
"vertical_size": config.minference_vertical_size,
"slash_size": config.minference_slash_size,
"adaptive_budget": config.minference_adaptive_budget,
"num_sink_tokens": config.minference_num_sink_tokens,
"num_recent_diags": config.minference_num_recent_diags,
}
# Get policy-specific parameters based on type
if config.sparse_policy == SparsePolicyType.XATTN:
policy_kwargs = {
"stride": config.xattn_stride,
"threshold": config.xattn_threshold,
"chunk_size": config.xattn_chunk_size,
"use_triton": config.xattn_use_triton,
"keep_sink": config.xattn_keep_sink,
"keep_recent": config.xattn_keep_recent,
"norm": config.xattn_norm,
"use_bsa": config.xattn_use_bsa,
}
elif config.sparse_policy == SparsePolicyType.MINFERENCE:
policy_kwargs = {
"vertical_size": config.minference_vertical_size,
"slash_size": config.minference_slash_size,
"adaptive_budget": config.minference_adaptive_budget,
"num_sink_tokens": config.minference_num_sink_tokens,
"num_recent_diags": config.minference_num_recent_diags,
}
else: # FULL or QUEST
policy_kwargs = {}
policy = create_sparse_policy(config.sparse_policy, **policy_kwargs)
# Only use if policy supports sparse prefill
if policy.supports_prefill:
self.sparse_prefill_policy = policy
logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}")
self.attention_policy = create_attention_policy(config.sparse_policy, **policy_kwargs)
logger.info(f"Attention policy: {self.attention_policy}")
# Allocate cache through manager
self.kvcache_manager.allocate_cache(
@@ -395,7 +392,7 @@ class ModelRunner:
set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
slot_mapping, None, block_tables,
sparse_prefill_policy=self.sparse_prefill_policy)
attention_policy=self.attention_policy)
return input_ids, positions
def prepare_decode(self, seqs: list[Sequence]):
@@ -592,21 +589,11 @@ class ModelRunner:
# RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k)
# Sparse or Full attention (uses k, v directly - before store!)
if self.sparse_prefill_policy is not None:
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, layer_id
)
else:
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
# Compute attention using policy (uses k, v directly - before store!)
attn_output = self.attention_policy.compute_prefill(
q, k, v, layer_id,
softmax_scale=layer.self_attn.attn.scale,
)
# O projection
attn_output = attn_output.view(total_tokens, -1)
@@ -872,23 +859,11 @@ class ModelRunner:
# RoPE
q, k = layer.self_attn.rotary_emb(positions, q, k)
# Sparse or Full attention
if self.sparse_prefill_policy is not None:
# MInference or other sparse prefill policy
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, layer_id
)
else:
# Full attention using FlashAttention
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=total_tokens,
max_seqlen_k=total_tokens,
softmax_scale=layer.self_attn.attn.scale,
causal=True,
)
# Compute attention using policy
attn_output = self.attention_policy.compute_prefill(
q, k, v, layer_id,
softmax_scale=layer.self_attn.attn.scale,
)
# O projection
attn_output = attn_output.view(total_tokens, -1)

View File

@@ -1,49 +1,56 @@
"""
Sparse Attention Policy module.
Attention Policy module for layerwise offload mode.
Provides pluggable policies for selecting which KV blocks to load
during chunked attention with CPU offload.
Provides pluggable policies for attention computation:
- FullAttentionPolicy: Standard FlashAttention (no sparsity)
- XAttentionPolicy: Sparse prefill using XAttention algorithm
- MInferencePolicy: MInference sparse attention
- QuestPolicy: Quest block selection (for chunked offload)
Usage:
from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
from nanovllm.kvcache.sparse import create_attention_policy, SparsePolicyType
# Create policy using factory function
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)
policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
# Use policy for attention
attn_output = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
# Or create custom policy
class MyPolicy(SparsePolicy):
class MyPolicy(AttentionPolicy):
supports_prefill = True
supports_decode = True
def select_blocks(self, available_blocks, ctx):
return available_blocks[:5] # Just first 5 blocks
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
# Custom attention computation
...
"""
from nanovllm.config import SparsePolicyType
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.policy import AttentionPolicy, SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
from nanovllm.kvcache.sparse.minference import MInferencePolicy
from nanovllm.kvcache.sparse.xattn import XAttentionPolicy
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
def create_attention_policy(policy_type: SparsePolicyType, **kwargs) -> AttentionPolicy:
"""
Create a sparse policy instance from an enum type.
Create an attention policy instance from an enum type.
The returned policy is not yet initialized. Call policy.initialize()
or let the framework call it during KV cache allocation.
All attention (including full attention) goes through a policy in layerwise
offload mode. The policy is responsible for computing prefill/decode attention.
Args:
policy_type: SparsePolicyType enum value
policy_type: SparsePolicyType enum value (FULL, XATTN, MINFERENCE, QUEST)
**kwargs: Policy-specific configuration options
Returns:
SparsePolicy instance (not initialized)
AttentionPolicy instance
Example:
policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
policy.initialize(num_layers=28, num_kv_heads=8, ...)
policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
attn_out = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
"""
if policy_type == SparsePolicyType.FULL:
return FullAttentionPolicy()
@@ -75,21 +82,32 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
keep_sink=kwargs.get("keep_sink", False),
keep_recent=kwargs.get("keep_recent", False),
norm=kwargs.get("norm", 1.0),
use_bsa=kwargs.get("use_bsa", True),
)
else:
raise ValueError(f"Unknown policy type: {policy_type}")
# Backward compatibility alias
create_sparse_policy = create_attention_policy
__all__ = [
# New interface
"AttentionPolicy",
"create_attention_policy",
# Backward compatibility
"SparsePolicy",
"create_sparse_policy",
# Common types
"PolicyContext",
"SparsePolicyType",
# Policy implementations
"FullAttentionPolicy",
"QuestPolicy",
"QuestConfig",
"BlockMetadataManager",
"MInferencePolicy",
"XAttentionPolicy",
"create_sparse_policy",
]

View File

@@ -1,20 +1,21 @@
"""
Full attention policy - loads all blocks (no sparsity).
Full attention policy - standard FlashAttention without sparsity.
This serves as a baseline and default policy when sparse
attention is not needed.
"""
from typing import List
from .policy import SparsePolicy, PolicyContext
from typing import Optional
import torch
from .policy import AttentionPolicy
class FullAttentionPolicy(SparsePolicy):
class FullAttentionPolicy(AttentionPolicy):
"""
Full attention policy that loads all available blocks.
Full attention policy using FlashAttention (no sparsity).
This is the default behavior with no sparsity - all previous
KV cache blocks are loaded for each query chunk.
This is the default behavior with standard causal attention.
All tokens attend to all previous tokens.
Use this as:
- A baseline for comparing sparse policies
@@ -25,15 +26,55 @@ class FullAttentionPolicy(SparsePolicy):
# Full attention supports both prefill and decode
supports_prefill = True
supports_decode = True
requires_block_selection = False # Load all blocks, no selective loading
def select_blocks(
def estimate(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
"""Return all blocks - no sparsity."""
return available_blocks
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Optional[torch.Tensor]:
"""
Full attention - no sparse mask needed.
Returns None to indicate full attention should be used.
"""
return None
def compute_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
) -> torch.Tensor:
    """
    Run dense causal attention with FlashAttention (no sparsity applied).

    The whole input is handed to the varlen kernel as one sequence whose
    boundaries are ``[0, seq_len]``.

    Args:
        q: Query tensor [seq_len, num_heads, head_dim]
        k: Key tensor [seq_len, num_kv_heads, head_dim]
        v: Value tensor [seq_len, num_kv_heads, head_dim]
        layer_id: Transformer layer index (not used by this policy)
        softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

    Returns:
        Attention output [seq_len, num_heads, head_dim]
    """
    # Imported lazily so the module can be loaded without flash_attn present.
    from flash_attn.flash_attn_interface import flash_attn_varlen_func

    total_len = q.shape[0]
    # Cumulative boundaries for a single packed sequence, shared by Q and K.
    boundaries = torch.tensor(
        [0, total_len], dtype=torch.int32, device=q.device
    )
    varlen_args = dict(
        cu_seqlens_q=boundaries,
        cu_seqlens_k=boundaries,
        max_seqlen_q=total_len,
        max_seqlen_k=total_len,
        softmax_scale=softmax_scale,
        causal=True,
    )
    return flash_attn_varlen_func(q, k, v, **varlen_args)
def __repr__(self) -> str:
return "FullAttentionPolicy()"

View File

@@ -10,10 +10,10 @@ from typing import List, Tuple, Optional
import torch
import torch.nn.functional as F
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.policy import AttentionPolicy, PolicyContext
class MInferencePolicy(SparsePolicy):
class MInferencePolicy(AttentionPolicy):
"""
MInference sparse prefill policy using vertical + slash pattern.
@@ -347,6 +347,33 @@ class MInferencePolicy(SparsePolicy):
return o
def compute_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
) -> torch.Tensor:
    """
    Compute MInference sparse prefill attention.

    This is the unified attention-policy entry point. It delegates to
    sparse_prefill_attention, which does not accept a softmax scale
    (MInference computes it internally from head_dim). Because the
    caller-supplied ``softmax_scale`` is therefore not honoured, a
    RuntimeWarning is emitted when it differs from the conventional
    1/sqrt(head_dim) so the mismatch is not silent.

    Args:
        q: Query tensor [seq_len, num_heads, head_dim]
        k: Key tensor [seq_len, num_kv_heads, head_dim]
        v: Value tensor [seq_len, num_kv_heads, head_dim]
        layer_id: Transformer layer index
        softmax_scale: Softmax scaling factor (not forwarded; only checked)

    Returns:
        Attention output [seq_len, num_heads, head_dim]
    """
    import math
    import warnings

    # NOTE(review): assumes the internal scale is 1/sqrt(head_dim) - confirm
    # against sparse_prefill_attention before relying on this check.
    implied_scale = q.shape[-1] ** -0.5
    if not math.isclose(softmax_scale, implied_scale, rel_tol=1e-4):
        warnings.warn(
            f"MInferencePolicy ignores softmax_scale={softmax_scale}; "
            f"attention is computed with the internal scale "
            f"(expected 1/sqrt(head_dim)={implied_scale}).",
            RuntimeWarning,
            stacklevel=2,
        )
    return self.sparse_prefill_attention(q, k, v, layer_id)
def __repr__(self) -> str:
return (f"MInferencePolicy("
f"adaptive_budget={self.adaptive_budget}, "

View File

@@ -1,13 +1,18 @@
"""
Base class for sparse attention policies.
Base class for attention policies in layerwise offload mode.
Sparse attention policies determine which KV cache blocks to load
from CPU for each query chunk during chunked attention computation.
AttentionPolicy defines the interface for all attention computation,
including full attention and sparse attention methods like XAttention.
Key methods:
- estimate(): Compute sparse attention mask (optional, returns None for full attention)
- compute_prefill(): Compute prefill attention
- compute_decode(): Compute decode attention (default implementation provided)
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Any
from typing import List, Optional, Tuple
import torch
# Import SparsePolicyType from config to avoid circular imports
@@ -17,10 +22,10 @@ from nanovllm.config import SparsePolicyType
@dataclass
class PolicyContext:
"""
Context passed to sparse policy for block selection.
Context passed to attention policy for block selection.
This dataclass contains all information needed by a sparse policy
to decide which blocks to load for the current query chunk.
This dataclass contains all information needed by an attention policy
for sparse estimation and attention computation.
"""
query_chunk_idx: int
@@ -49,40 +54,41 @@ class PolicyContext:
"""Total KV sequence length so far (for reference)."""
class SparsePolicy(ABC):
class AttentionPolicy(ABC):
"""
Abstract base class for sparse attention policies.
Base class for attention policies in layerwise offload mode.
Subclass this and implement select_blocks() to create custom
sparse attention patterns. The policy receives context about
the current query chunk and returns which KV blocks to load.
All attention computation goes through a policy, including both
full attention and sparse attention methods.
The policy interface is designed for layerwise offload where:
- The entire KV cache for a layer is on GPU during computation
- No need for block loading from CPU during attention
- estimate() returns a sparse mask (or None for full attention)
- compute_prefill()/compute_decode() perform the actual attention
Attributes:
supports_prefill: Whether this policy can be used for prefill phase.
supports_decode: Whether this policy can be used for decode phase.
Example:
class MySparsePolicy(SparsePolicy):
supports_prefill = False # decode-only policy
class MyPolicy(AttentionPolicy):
supports_prefill = True
supports_decode = True
def select_blocks(self, available_blocks, ctx):
# Load first block and last 2 blocks
if len(available_blocks) <= 3:
return available_blocks
return [available_blocks[0]] + available_blocks[-2:]
def estimate(self, q, k, layer_id):
# Return sparse mask or None
return None
def compute_prefill(self, q, k, v, layer_id, softmax_scale):
# Compute attention
return flash_attn_varlen_func(q, k, v, ...)
"""
# Compatibility flags - override in subclasses
supports_prefill: bool = True
supports_decode: bool = True
# Whether this policy requires selective block loading during decode
# If True: OffloadEngine will call select_blocks() before loading KV from CPU
# If False: OffloadEngine will load all blocks (select_blocks ignored for load)
# Example: MInference=False (only affects attention), Quest=True (affects load)
requires_block_selection: bool = False
def initialize(
self,
num_layers: int,
@@ -96,7 +102,7 @@ class SparsePolicy(ABC):
Initialize policy resources.
Called by the framework after KV cache is allocated. Override this
to create metadata structures (e.g., BlockMetadataManager for Quest).
to create metadata structures or pre-allocate buffers.
Default implementation does nothing.
Args:
@@ -109,76 +115,98 @@ class SparsePolicy(ABC):
"""
pass
@abstractmethod
def select_blocks(
def estimate(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Optional[torch.Tensor]:
"""
Select which KV blocks to load for the current query chunk.
Estimate sparse attention mask.
This is the core method that defines the sparse attention pattern.
The returned blocks will be loaded from CPU to GPU for attention
computation against the current query chunk.
For sparse policies (e.g., XAttention), computes block-level importance
and returns a boolean mask indicating which blocks to attend.
For full attention policy, returns None.
This corresponds to xattn_estimate() in COMPASS.
Args:
available_blocks: List of CPU block IDs that contain KV cache
from previous chunks. These are ordered by
their position in the sequence.
ctx: PolicyContext with information about the current query
chunk, layer, phase (prefill/decode), etc.
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
Returns:
List of block IDs to load (must be a subset of available_blocks).
The order may affect performance (sequential access is faster).
Returning [] means no previous blocks will be loaded.
sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
or None for full attention
"""
pass
return None
def on_prefill_offload(
@abstractmethod
def compute_prefill(
self,
cpu_block_id: int,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
softmax_scale: float,
) -> torch.Tensor:
"""
Hook called when a block is offloaded during prefill phase.
Compute prefill attention.
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
Override this to collect metadata about blocks (e.g., min/max keys
for Quest-style selection). Default implementation does nothing.
The entire KV cache for this layer is on GPU. Compute attention
between Q and K/V, optionally using sparse mask from estimate().
Args:
cpu_block_id: The CPU block ID that will be written
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
pass
def on_decode_offload(
def compute_decode(
self,
cpu_block_id: int,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
softmax_scale: float,
) -> torch.Tensor:
"""
Hook called when a block is offloaded during decode phase.
Compute decode attention.
Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
Override this to update metadata about blocks. Default implementation
does nothing.
KV is provided from ring buffer, containing prefill tokens + decoded tokens.
Default implementation uses FlashAttention.
Args:
cpu_block_id: The CPU block ID that will be written
q: Query tensor [1, num_heads, head_dim]
k: Key tensor [context_len+1, num_kv_heads, head_dim]
v: Value tensor [context_len+1, num_kv_heads, head_dim]
layer_id: Transformer layer index
k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
softmax_scale: Softmax scaling factor
Returns:
Attention output [1, num_heads, head_dim]
"""
pass
from flash_attn.flash_attn_interface import flash_attn_varlen_func
context_len = k.shape[0]
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1,
max_seqlen_k=context_len,
softmax_scale=softmax_scale,
causal=False,
)
def reset(self) -> None:
"""
@@ -189,32 +217,9 @@ class SparsePolicy(ABC):
"""
pass
def sparse_prefill_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
) -> torch.Tensor:
"""
Compute sparse attention for prefill phase.
This method is called when supports_prefill=True and the policy
is used for GPU-only sparse prefill (no CPU offload).
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Current transformer layer index
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement sparse_prefill_attention. "
"Set supports_prefill=False or implement this method."
)
def __repr__(self) -> str:
return f"{self.__class__.__name__}()"
# Backward compatibility alias
SparsePolicy = AttentionPolicy

View File

@@ -11,7 +11,7 @@ import logging
import torch
from dataclasses import dataclass
from typing import List, Tuple, Optional
from .policy import SparsePolicy, PolicyContext
from .policy import AttentionPolicy, PolicyContext
logger = logging.getLogger(__name__)
@@ -137,7 +137,7 @@ class QuestConfig:
"""Always include this many recent blocks (last N blocks), in addition to Top-K."""
class QuestPolicy(SparsePolicy):
class QuestPolicy(AttentionPolicy):
"""
Quest-style Top-K block selection using min/max key bounds.
@@ -317,6 +317,25 @@ class QuestPolicy(SparsePolicy):
if self.metadata is not None:
self.metadata.reset()
def compute_prefill(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_id: int,
    softmax_scale: float,
) -> torch.Tensor:
    """
    Compute prefill attention for Quest using dense causal attention.

    Quest is a decode-only block-selection policy and has no sparse
    prefill of its own. The runner routes every prefill through the
    configured policy, so raising here would make prefill fail whenever
    Quest is selected. Delegating to FullAttentionPolicy keeps prefill
    dense, as it is for any policy without sparse prefill.

    Args:
        q: Query tensor [seq_len, num_heads, head_dim]
        k: Key tensor [seq_len, num_kv_heads, head_dim]
        v: Value tensor [seq_len, num_kv_heads, head_dim]
        layer_id: Transformer layer index
        softmax_scale: Softmax scaling factor

    Returns:
        Attention output [seq_len, num_heads, head_dim]
    """
    # Function-scope import keeps this module's top-level imports unchanged.
    from .full_policy import FullAttentionPolicy

    return FullAttentionPolicy().compute_prefill(
        q, k, v, layer_id,
        softmax_scale=softmax_scale,
    )
def __repr__(self) -> str:
return (
f"QuestPolicy(topk={self.config.topk_blocks}, "

View File

@@ -4,48 +4,56 @@ XAttention sparse attention policy for nano-vllm.
Implements the XAttention algorithm from COMPASS, using chunked estimation
and block sparse attention for efficient long-context inference.
Architecture:
XAttention = Estimate (Triton) + Compute (BSA)
- Estimate: xattn_estimate() computes block-level importance scores
- Compute: block_sparse_attn_func() executes sparse attention
Reference: COMPASS/compass/src/Xattention.py
"""
import math
from typing import List, Optional
from typing import Optional
import torch
import torch.nn.functional as F
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.kvcache.sparse.kernels import (
flat_group_gemm_fuse_reshape,
softmax_fuse_block_sum,
)
from nanovllm.kvcache.sparse.utils import find_blocks_chunked
from nanovllm.kvcache.sparse.policy import AttentionPolicy
# BSA block size is fixed at 128 (hardcoded in block_sparse_attn)
BSA_BLOCK_SIZE = 128
class XAttentionPolicy(SparsePolicy):
class XAttentionPolicy(AttentionPolicy):
"""
XAttention sparse prefill policy using chunked estimation + block sparse attention.
This policy estimates sparse attention patterns by:
1. Chunked QK computation using Triton kernels
1. Chunked QK computation using Triton kernels (via nanovllm.ops.xattn)
2. Block-wise softmax with importance scores
3. Block selection based on threshold
4. Block sparse attention computation
4. Block sparse attention computation using MIT-HAN-LAB BSA library
The key method is estimate() which calls xattn_estimate() from nanovllm.ops
to compute the sparse attention mask.
Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
BSA library: https://github.com/mit-han-lab/Block-Sparse-Attention
"""
supports_prefill = True
supports_decode = False # XAttention is prefill-only
requires_block_selection = False # Only affects attention computation
supports_decode = True # Uses default FlashAttention for decode
def __init__(
self,
stride: int = 8,
threshold: float = 0.9,
chunk_size: Optional[int] = None,
block_size: int = 128,
chunk_size: int = 16384,
use_triton: bool = True,
keep_sink: bool = False,
keep_recent: bool = False,
norm: float = 1.0,
use_bsa: bool = True,
):
"""
Initialize XAttention policy.
@@ -53,19 +61,28 @@ class XAttentionPolicy(SparsePolicy):
Args:
stride: Stride for reorganizing Q/K (default: 8)
threshold: Block selection threshold, 0-1 (default: 0.9)
chunk_size: Chunk size for estimation (auto if None)
block_size: Block size for sparse attention (default: 128, must match BSA)
chunk_size: Chunk size for estimation (default: 16384)
use_triton: Use Triton kernels (requires SM 80+)
keep_sink: Always keep first block (sink tokens)
keep_recent: Always keep recent diagonal blocks
norm: Normalization factor for attention scores
use_bsa: Use Block Sparse Attention library (default: True)
"""
self.stride = stride
self.threshold = threshold
self.block_size = block_size
self.chunk_size = chunk_size
self.use_triton = use_triton
self.keep_sink = keep_sink
self.keep_recent = keep_recent
self.norm = norm
self.use_bsa = use_bsa
# BSA requires block_size = 128
if self.use_bsa and self.block_size != BSA_BLOCK_SIZE:
print(f"XAttention: BSA requires block_size=128, adjusting from {self.block_size}")
self.block_size = BSA_BLOCK_SIZE
# Check Triton availability
if self.use_triton:
@@ -79,379 +96,206 @@ class XAttentionPolicy(SparsePolicy):
self.use_triton = False
print("XAttention: Triton not available. Falling back to PyTorch.")
def select_blocks(
# Check BSA availability
if self.use_bsa:
try:
from block_sparse_attn import block_sparse_attn_func
except ImportError:
self.use_bsa = False
print("XAttention: block_sparse_attn not available. Falling back to FlashAttention.")
def estimate(
self,
available_blocks: List[int],
ctx: PolicyContext,
) -> List[int]:
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Optional[torch.Tensor]:
"""
Select blocks for decode phase.
Estimate sparse attention mask using XAttention algorithm.
XAttention is prefill-only, so this method is only used as a fallback.
Returns all available blocks by default.
Calls xattn_estimate() from nanovllm.ops.xattn to compute block-level
importance scores and generate a sparse boolean mask.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
layer_id: Transformer layer index
Returns:
sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
or None if estimation fails (fallback to full attention)
"""
# XAttention is prefill-only, but we need to implement this abstract method
# Since requires_block_selection=False, this won't be called for loading
return available_blocks
try:
from nanovllm.ops.xattn import xattn_estimate
def sparse_prefill_attention(
seq_len, num_heads, head_dim = q.shape
num_kv_heads = k.shape[1]
# Convert to [batch, heads, seq, dim] format expected by xattn_estimate
q_bhsd = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim]
k_bhsd = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, seq_len, head_dim]
# Handle GQA: expand k so each query head is scored against the KV head it
# actually attends to. Query head i uses KV head i // repeat_factor, which is
# the grouped layout produced by repeat_interleave. Tiling with .repeat()
# would pair head i with KV head i % num_kv_heads and estimate the mask
# against the wrong keys.
if num_kv_heads != num_heads:
    repeat_factor = num_heads // num_kv_heads
    k_bhsd = k_bhsd.repeat_interleave(repeat_factor, dim=1)
# Call xattn_estimate
attn_sums, sparse_mask = xattn_estimate(
q_bhsd, k_bhsd,
block_size=self.block_size,
stride=self.stride,
norm=self.norm,
threshold=self.threshold,
chunk_size=self.chunk_size,
use_triton=self.use_triton,
causal=True,
keep_sink=self.keep_sink,
keep_recent=self.keep_recent,
)
return sparse_mask
except Exception as e:
# If estimation fails, return None to use full attention
print(f"XAttention estimate failed: {e}, falling back to full attention")
return None
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute XAttention sparse attention for prefill.
Compute XAttention sparse prefill attention.
Flow:
1. Call estimate() to get sparse mask
2. If mask is None or BSA unavailable, use full FlashAttention
3. Otherwise, use block_sparse_attn_func with mask
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Current transformer layer index
layer_id: Transformer layer index
softmax_scale: Softmax scaling factor
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
seq_len = q.shape[0]
num_heads = q.shape[1]
head_dim = q.shape[2]
num_kv_heads = k.shape[1]
# If BSA is disabled, use full attention directly (skip estimation)
if not self.use_bsa:
return self._full_attention(q, k, v, softmax_scale)
# Use FlashAttention directly for CPU offload mode
# FlashAttention supports GQA natively
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# Step 1: Estimate sparse mask
sparse_mask = self.estimate(q, k, layer_id)
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
# Step 2: Compute attention
if sparse_mask is None:
# Estimation failed, fallback to full FlashAttention
return self._full_attention(q, k, v, softmax_scale)
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=1.0 / math.sqrt(head_dim),
causal=True,
)
# Use block sparse attention with mask
return self._block_sparse_attention(q, k, v, sparse_mask, softmax_scale)
return attn_output
except Exception as e:
# Fallback: PyTorch SDPA (supports GQA natively)
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
attn_output = F.scaled_dot_product_attention(
q, k, v,
attn_mask=None,
is_causal=True,
scale=1.0 / math.sqrt(head_dim)
)
return attn_output
def _xattn_offload_prefill(
def _block_sparse_attention(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
causal: bool = True,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
sparse_mask: torch.Tensor,
softmax_scale: float,
) -> torch.Tensor:
"""
Simplified XAttention prefill for CPU offload mode.
Uses FlashAttention with full context since chunked estimation
with full key_states requires special handling.
"""
batch_size, num_heads, q_len, head_dim = query_states.shape
_, _, k_len, _ = key_states.shape
# Use FlashAttention with full context
# In offload mode, keys are already on CPU and loaded as needed
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# Convert to [seq, heads, dim] format
q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim]
k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=k_len,
softmax_scale=1.0 / math.sqrt(head_dim),
causal=causal,
)
# Convert back to [batch, seq, heads, dim]
attn_output = attn_output.unsqueeze(0).transpose(1, 2) # [1, q_len, num_heads, head_dim]
return attn_output
except Exception as e:
# Final fallback: PyTorch SDPA
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states,
attn_mask=None,
is_causal=causal,
scale=1.0 / math.sqrt(head_dim)
)
return attn_output
def _xattn_prefill(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
stride: int,
norm: float,
threshold: float,
block_size: int = 128,
use_triton: bool = True,
causal: bool = True,
chunk_size: Optional[int] = None,
keep_sink: bool = False,
keep_recent: bool = False,
) -> torch.Tensor:
"""
XAttention prefill implementation.
Compute block sparse attention using MIT-HAN-LAB BSA library.
Args:
query_states: [batch, num_heads, q_len, head_dim]
key_states: [batch, num_heads, k_len, head_dim]
value_states: [batch, num_heads, k_len, head_dim]
... other params
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
sparse_mask: Block mask [batch, num_heads, q_blocks, k_blocks]
softmax_scale: Softmax scaling factor
Returns:
Attention output [batch, q_len, num_heads, head_dim]
Attention output [seq_len, num_heads, head_dim]
"""
batch_size, num_heads, k_len, head_dim = key_states.shape
_, _, q_len, _ = query_states.shape
from block_sparse_attn import block_sparse_attn_func
# Auto-compute chunk_size if not specified
if chunk_size is None:
chunk_size = int(
max(
min(
max(2048, 1 << (k_len - 1).bit_length()),
128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()),
),
2048,
)
)
seq_len, num_heads, head_dim = q.shape
num_kv_heads = k.shape[1]
# Phase 1: Estimate sparse pattern
attn_sums, approx_simple_mask = self._xattn_estimate(
query_states,
key_states,
block_size=block_size,
stride=stride,
norm=norm,
threshold=threshold,
chunk_size=chunk_size,
use_triton=use_triton,
causal=causal,
keep_sink=keep_sink,
keep_recent=keep_recent,
)
# Handle GQA: expand K/V to match Q heads
if num_kv_heads != num_heads:
repeat_factor = num_heads // num_kv_heads
k = k.repeat_interleave(repeat_factor, dim=1)
v = v.repeat_interleave(repeat_factor, dim=1)
# Phase 2: Block sparse attention
# For now, use FlashAttention as fallback since block_sparse_attn_func may not be available
attn_output = self._block_sparse_attention_fallback(
query_states, key_states, value_states,
approx_simple_mask, block_size, q_len, k_len
# Cumulative sequence lengths (batch=1)
cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
# Head mask type: 1 for all heads using block sparse
head_mask_type = torch.ones(num_heads, dtype=torch.int32, device=q.device)
# Trim sparse_mask to actual block counts
q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
block_mask = sparse_mask[:, :, :q_blocks, :k_blocks].contiguous()
# Call BSA
attn_output = block_sparse_attn_func(
q, k, v,
cu_seqlens_q, cu_seqlens_k,
head_mask_type,
None, # streaming_info (left_mask)
block_mask,
seq_len, seq_len,
p_dropout=0.0,
deterministic=True,
softmax_scale=softmax_scale,
is_causal=True,
)
return attn_output
def _xattn_estimate(
def _full_attention(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
block_size: int,
stride: int,
norm: float = 1,
softmax: bool = True,
threshold: float = 0.9,
chunk_size: int = 16384,
use_triton: bool = True,
causal: bool = True,
keep_sink: bool = False,
keep_recent: bool = False,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: float,
) -> torch.Tensor:
"""
Estimate sparse attention pattern using chunked computation.
Compute full causal attention using FlashAttention.
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
softmax_scale: Softmax scaling factor
Returns:
attn_sums: [batch, heads, q_blocks, k_blocks] - importance scores
simple_masks: [batch, heads, q_blocks, k_blocks] - boolean masks
Attention output [seq_len, num_heads, head_dim]
"""
batch_size, num_kv_head, k_len, head_dim = key_states.shape
batch_size, num_q_head, q_len, head_dim = query_states.shape
from flash_attn.flash_attn_interface import flash_attn_varlen_func
k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
k_chunk_num = (k_len + k_num_to_pad) // chunk_size
k_block_num = (k_len + k_num_to_pad) // block_size
q_chunk_num = (q_len + q_num_to_pad) // chunk_size
q_block_num = (q_len + q_num_to_pad) // block_size
seq_len = q.shape[0]
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
# Pad inputs
if k_num_to_pad > 0:
pad_key_states = F.pad(key_states, (0, 0, 0, k_num_to_pad), value=0)
else:
pad_key_states = key_states
if q_num_to_pad > 0:
pad_query_states = F.pad(query_states, (0, 0, 0, q_num_to_pad), value=0)
else:
pad_query_states = query_states
reshaped_chunk_size = chunk_size // stride
reshaped_block_size = block_size // stride
k_reshaped_seq_len = (k_len + k_num_to_pad) // stride
attn_sum_list = []
simple_mask_list = []
for chunk_idx in range(q_chunk_num):
if use_triton:
# Triton GEMM + Softmax
attn_weights_slice = flat_group_gemm_fuse_reshape(
pad_query_states[:, :, (chunk_idx * reshaped_chunk_size) * stride : (chunk_idx * reshaped_chunk_size + reshaped_chunk_size) * stride, :],
pad_key_states,
stride,
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
is_causal=causal,
)
attn_sum = softmax_fuse_block_sum(
attn_weights_slice,
reshaped_block_size,
min(4096, reshaped_block_size),
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
k_reshaped_seq_len - (k_num_to_pad // stride),
1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
is_causal=causal,
)
else:
# PyTorch fallback
chunk_size_actual = reshaped_chunk_size
chunk_start = chunk_idx * chunk_size_actual
chunk_end = chunk_start + chunk_size_actual
chunked_query = pad_query_states[:, :, chunk_start * stride:chunk_end * stride:stride, :]
attn_weights_slice = torch.matmul(chunked_query, pad_key_states.transpose(2, 3))
attn_weights_slice = attn_weights_slice / math.sqrt(head_dim) / stride / norm
if causal:
causal_mask = torch.zeros((batch_size, num_q_head, chunk_size_actual, chunk_size_actual * k_chunk_num), device=key_states.device)
causal_mask[:, :, :, -(k_num_to_pad // stride):] = float("-inf")
# ... more causal mask logic ...
attn_weights_slice = attn_weights_slice + causal_mask
attn_weights_slice = F.softmax(attn_weights_slice, dim=-1, dtype=torch.float32)
attn_sum = attn_weights_slice.view(batch_size, num_q_head, chunk_size_actual // reshaped_block_size, reshaped_block_size, -1).sum(dim=-1).sum(dim=-2)
# Find blocks based on threshold
simple_mask = find_blocks_chunked(
attn_sum,
k_block_num - q_block_num + chunk_idx * (reshaped_chunk_size // reshaped_block_size),
threshold,
None,
decoding=False,
mode="prefill",
causal=causal,
)
attn_sum_list.append(attn_sum)
simple_mask_list.append(simple_mask)
attn_sums = torch.cat(attn_sum_list, dim=-2)
simple_masks = torch.cat(simple_mask_list, dim=-2)
# Apply causal mask to block masks
if causal:
simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=key_states.device), diagonal=0),
simple_masks[:, :, -q_block_num:, -q_block_num:],
False,
)
if keep_sink:
simple_masks[:, :, 0, :] = True
if keep_recent:
eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=bool)
eye_matrix_expanded = eye_matrix.unsqueeze(0).unsqueeze(0).expand(1, num_q_head, q_block_num, q_block_num)
simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:]
)
return attn_sums, simple_masks
def _block_sparse_attention_fallback(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
mask: torch.Tensor,
block_size: int,
q_len: int,
k_len: int,
) -> torch.Tensor:
"""
Fallback implementation using FlashAttention.
Since block_sparse_attn_func may not be available in all environments,
this uses standard FlashAttention with full attention.
"""
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
batch_size, num_heads, _, head_dim = query_states.shape
# Convert to [seq, heads, dim] format
q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim]
k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=k_len,
softmax_scale=1.0 / math.sqrt(head_dim),
causal=True,
)
# Convert back to [batch, seq, heads, dim]
attn_output = attn_output.unsqueeze(0).transpose(1, 2)
return attn_output
except Exception as e:
# Final fallback: PyTorch SDPA
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states,
attn_mask=None,
is_causal=True,
scale=1.0 / math.sqrt(query_states.shape[-1])
)
return attn_output
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=softmax_scale,
causal=True,
)
def reset(self) -> None:
"""Reset policy state (no state to reset for XAttention)."""
@@ -461,4 +305,6 @@ class XAttentionPolicy(SparsePolicy):
return (f"XAttentionPolicy("
f"stride={self.stride}, "
f"threshold={self.threshold}, "
f"use_triton={self.use_triton})")
f"block_size={self.block_size}, "
f"use_triton={self.use_triton}, "
f"use_bsa={self.use_bsa})")

View File

@@ -98,10 +98,10 @@ class Attention(nn.Module):
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
elif context.sparse_prefill_policy is not None:
# Sparse prefill (GPU-only) - delegate to policy
o = context.sparse_prefill_policy.sparse_prefill_attention(
q, k, v, self.layer_id
elif context.attention_policy is not None:
# Attention via policy (GPU-only) - delegate to policy
o = context.attention_policy.compute_prefill(
q, k, v, self.layer_id, softmax_scale=self.scale
)
else:
o = flash_attn_varlen_func(q, k, v,

View File

@@ -14,9 +14,9 @@ class Context:
context_lens: torch.Tensor | None = None
block_tables: torch.Tensor | None = None
# Sparse prefill attention support (GPU-only path)
# When set, uses policy.sparse_prefill_attention() instead of FlashAttention
sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True
# Attention policy support (GPU-only path)
# When set, uses policy.compute_prefill() instead of FlashAttention
attention_policy: Any = None # AttentionPolicy instance
_CONTEXT = Context()
@@ -35,7 +35,7 @@ def set_context(
slot_mapping=None,
context_lens=None,
block_tables=None,
sparse_prefill_policy=None,
attention_policy=None,
):
global _CONTEXT
_CONTEXT = Context(
@@ -47,7 +47,7 @@ def set_context(
slot_mapping=slot_mapping,
context_lens=context_lens,
block_tables=block_tables,
sparse_prefill_policy=sparse_prefill_policy,
attention_policy=attention_policy,
)

View File

@@ -32,11 +32,14 @@ def run_needle_test(
enable_cpu_offload: bool = False,
enable_quest: bool = False,
enable_minference: bool = False,
enable_xattn: bool = False,
sparse_topk: int = 8,
sparse_threshold: int = 4,
minference_budget: float = 0.3,
minference_vertical: int = 1000,
minference_slash: int = 6096,
xattn_threshold: float = 0.9,
xattn_use_bsa: bool = True,
gpu_utilization: float = 0.9,
enforce_eager: bool = True,
verbose: bool = True,
@@ -56,11 +59,14 @@ def run_needle_test(
enable_cpu_offload: Enable CPU offload mode
enable_quest: Enable Quest sparse attention (decode-only Top-K)
enable_minference: Enable MInference sparse prefill (GPU-only)
enable_xattn: Enable XAttention sparse prefill with BSA
sparse_topk: Top-K blocks for Quest
sparse_threshold: Apply sparse only when blocks > threshold
minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode)
minference_vertical: Fixed vertical_size (only used when budget=None)
minference_slash: Fixed slash_size (only used when budget=None)
xattn_threshold: XAttention block selection threshold (0-1)
xattn_use_bsa: Use Block Sparse Attention library
gpu_utilization: GPU memory utilization fraction
verbose: Print detailed output
@@ -68,7 +74,9 @@ def run_needle_test(
True if test passed, False otherwise
"""
# Determine sparse policy
if enable_minference:
if enable_xattn:
sparse_policy = SparsePolicyType.XATTN
elif enable_minference:
sparse_policy = SparsePolicyType.MINFERENCE
elif enable_quest:
sparse_policy = SparsePolicyType.QUEST
@@ -94,6 +102,8 @@ def run_needle_test(
print(f" MInference: adaptive (budget={minference_budget})")
else:
print(f" MInference: fixed (vertical={minference_vertical}, slash={minference_slash})")
if enable_xattn:
print(f" XAttention: threshold={xattn_threshold}, use_bsa={xattn_use_bsa}")
print(f"{'='*60}\n")
# 1. Initialize LLM
@@ -111,7 +121,7 @@ def run_needle_test(
llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
# Set sparse policy (can be used with or without offload)
if enable_minference or enable_quest:
if enable_minference or enable_quest or enable_xattn:
llm_kwargs["sparse_policy"] = sparse_policy
# MInference params (works with both GPU-only and offload mode)
@@ -120,6 +130,11 @@ def run_needle_test(
llm_kwargs["minference_vertical_size"] = minference_vertical
llm_kwargs["minference_slash_size"] = minference_slash
# XAttention params
if enable_xattn:
llm_kwargs["xattn_threshold"] = xattn_threshold
llm_kwargs["xattn_use_bsa"] = xattn_use_bsa
llm = LLM(model_path, **llm_kwargs)
# 2. Generate needle prompt
@@ -224,6 +239,11 @@ if __name__ == "__main__":
action="store_true",
help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)"
)
parser.add_argument(
"--enable-xattn",
action="store_true",
help="Enable XAttention sparse prefill with Block Sparse Attention"
)
parser.add_argument(
"--sparse-topk",
type=int,
@@ -254,6 +274,17 @@ if __name__ == "__main__":
default=6096,
help="Fixed slash_size (only used when budget=0)"
)
parser.add_argument(
"--xattn-threshold",
type=float,
default=0.9,
help="XAttention block selection threshold (0-1, higher=more blocks)"
)
parser.add_argument(
"--xattn-no-bsa",
action="store_true",
help="Disable Block Sparse Attention (use FlashAttention fallback)"
)
parser.add_argument(
"--gpu-utilization",
type=float,
@@ -291,11 +322,14 @@ if __name__ == "__main__":
enable_cpu_offload=args.enable_offload,
enable_quest=args.enable_quest,
enable_minference=args.enable_minference,
enable_xattn=args.enable_xattn,
sparse_topk=args.sparse_topk,
sparse_threshold=args.sparse_threshold,
minference_budget=minference_budget,
minference_vertical=args.minference_vertical,
minference_slash=args.minference_slash,
xattn_threshold=args.xattn_threshold,
xattn_use_bsa=not args.xattn_no_bsa,
gpu_utilization=args.gpu_utilization,
enforce_eager=enforce_eager,
verbose=True,