From 5fb0f672955c0081337f02da9cd97ea15501bd36 Mon Sep 17 00:00:00 2001 From: Zijie Tian Date: Thu, 22 Jan 2026 22:20:34 +0800 Subject: [PATCH] [WIP] need refactor. --- nanovllm/config.py | 1 + nanovllm/engine/model_runner.py | 105 ++--- nanovllm/kvcache/sparse/__init__.py | 54 ++- nanovllm/kvcache/sparse/full_policy.py | 69 +++- nanovllm/kvcache/sparse/minference.py | 31 +- nanovllm/kvcache/sparse/policy.py | 197 ++++----- nanovllm/kvcache/sparse/quest.py | 23 +- nanovllm/kvcache/sparse/xattn.py | 526 +++++++++---------------- nanovllm/layers/attention.py | 8 +- nanovllm/utils/context.py | 10 +- tests/test_needle.py | 38 +- 11 files changed, 514 insertions(+), 548 deletions(-) diff --git a/nanovllm/config.py b/nanovllm/config.py index bba24df..12d9f63 100644 --- a/nanovllm/config.py +++ b/nanovllm/config.py @@ -62,6 +62,7 @@ class Config: xattn_keep_sink: bool = False # Always keep first block (sink tokens) xattn_keep_recent: bool = False # Always keep recent diagonal blocks xattn_norm: float = 1.0 # Normalization factor for attention scores + xattn_use_bsa: bool = True # Use Block Sparse Attention library (requires installation) def __post_init__(self): assert os.path.isdir(self.model) diff --git a/nanovllm/engine/model_runner.py b/nanovllm/engine/model_runner.py index da8c165..ff05fa2 100644 --- a/nanovllm/engine/model_runner.py +++ b/nanovllm/engine/model_runner.py @@ -57,8 +57,8 @@ class ModelRunner: load_model(self.model, config.model) self.sampler = GreedySampler() - # Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache) - self.sparse_prefill_policy = None + # Initialize attention_policy before warmup (will be configured in allocate_kv_cache) + self.attention_policy = None #> Disable warmup for debugging self.warmup_model() @@ -178,38 +178,35 @@ class ModelRunner: # Create KV cache manager using factory self.kvcache_manager: KVCacheManager = create_kvcache_manager(config) - # Create sparse prefill policy - # This is used for both GPU-only and CPU offload modes when policy supports prefill - self.sparse_prefill_policy = None - if config.sparse_policy != SparsePolicyType.FULL: - from nanovllm.kvcache.sparse import create_sparse_policy + # Create attention policy (always, including FULL) + # In layerwise offload mode, all attention goes through the policy + from nanovllm.kvcache.sparse import create_attention_policy - # Get policy-specific parameters based on type - if config.sparse_policy == SparsePolicyType.XATTN: - policy_kwargs = { - "stride": config.xattn_stride, - "threshold": config.xattn_threshold, - "chunk_size": config.xattn_chunk_size, - "use_triton": config.xattn_use_triton, - "keep_sink": config.xattn_keep_sink, - "keep_recent": config.xattn_keep_recent, - "norm": config.xattn_norm, - } - else: # MINFERENCE or others - policy_kwargs = { - "vertical_size": config.minference_vertical_size, - "slash_size": config.minference_slash_size, - "adaptive_budget": config.minference_adaptive_budget, - "num_sink_tokens": config.minference_num_sink_tokens, - "num_recent_diags": config.minference_num_recent_diags, - } + # Get policy-specific parameters based on type + if config.sparse_policy == SparsePolicyType.XATTN: + policy_kwargs = { + "stride": config.xattn_stride, + "threshold": config.xattn_threshold, + "chunk_size": config.xattn_chunk_size, + "use_triton": config.xattn_use_triton, + "keep_sink": config.xattn_keep_sink, + "keep_recent": config.xattn_keep_recent, + "norm": config.xattn_norm, + "use_bsa": config.xattn_use_bsa, + } + elif config.sparse_policy == SparsePolicyType.MINFERENCE: + policy_kwargs = { + "vertical_size": config.minference_vertical_size, + "slash_size": config.minference_slash_size, + "adaptive_budget": config.minference_adaptive_budget, + "num_sink_tokens": config.minference_num_sink_tokens, + "num_recent_diags": config.minference_num_recent_diags, + } + else: # FULL or QUEST + policy_kwargs = {} - policy = create_sparse_policy(config.sparse_policy, **policy_kwargs) - - # Only use if policy supports sparse prefill - if policy.supports_prefill: - self.sparse_prefill_policy = policy - logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}") + self.attention_policy = create_attention_policy(config.sparse_policy, **policy_kwargs) + logger.info(f"Attention policy: {self.attention_policy}") # Allocate cache through manager self.kvcache_manager.allocate_cache( @@ -395,7 +392,7 @@ class ModelRunner: set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables, - sparse_prefill_policy=self.sparse_prefill_policy) + attention_policy=self.attention_policy) return input_ids, positions def prepare_decode(self, seqs: list[Sequence]): @@ -592,21 +589,11 @@ class ModelRunner: # RoPE q, k = layer.self_attn.rotary_emb(positions, q, k) - # Sparse or Full attention (uses k, v directly - before store!) - if self.sparse_prefill_policy is not None: - attn_output = self.sparse_prefill_policy.sparse_prefill_attention( - q, k, v, layer_id - ) - else: - attn_output = flash_attn_varlen_func( - q, k, v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=total_tokens, - max_seqlen_k=total_tokens, - softmax_scale=layer.self_attn.attn.scale, - causal=True, - ) + # Compute attention using policy (uses k, v directly - before store!) + attn_output = self.attention_policy.compute_prefill( + q, k, v, layer_id, + softmax_scale=layer.self_attn.attn.scale, + ) # O projection attn_output = attn_output.view(total_tokens, -1) @@ -872,23 +859,11 @@ class ModelRunner: # RoPE q, k = layer.self_attn.rotary_emb(positions, q, k) - # Sparse or Full attention - if self.sparse_prefill_policy is not None: - # MInference or other sparse prefill policy - attn_output = self.sparse_prefill_policy.sparse_prefill_attention( - q, k, v, layer_id - ) - else: - # Full attention using FlashAttention - attn_output = flash_attn_varlen_func( - q, k, v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=total_tokens, - max_seqlen_k=total_tokens, - softmax_scale=layer.self_attn.attn.scale, - causal=True, - ) + # Compute attention using policy + attn_output = self.attention_policy.compute_prefill( + q, k, v, layer_id, + softmax_scale=layer.self_attn.attn.scale, + ) # O projection attn_output = attn_output.view(total_tokens, -1) diff --git a/nanovllm/kvcache/sparse/__init__.py b/nanovllm/kvcache/sparse/__init__.py index 4601ccf..d7ed67a 100644 --- a/nanovllm/kvcache/sparse/__init__.py +++ b/nanovllm/kvcache/sparse/__init__.py @@ -1,49 +1,56 @@ """ -Sparse Attention Policy module. +Attention Policy module for layerwise offload mode. -Provides pluggable policies for selecting which KV blocks to load -during chunked attention with CPU offload. +Provides pluggable policies for attention computation: +- FullAttentionPolicy: Standard FlashAttention (no sparsity) +- XAttentionPolicy: Sparse prefill using XAttention algorithm +- MInferencePolicy: MInference sparse attention +- QuestPolicy: Quest block selection (for chunked offload) Usage: - from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType + from nanovllm.kvcache.sparse import create_attention_policy, SparsePolicyType # Create policy using factory function - policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8) + policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9) + + # Use policy for attention + attn_output = policy.compute_prefill(q, k, v, layer_id, softmax_scale) # Or create custom policy - class MyPolicy(SparsePolicy): + class MyPolicy(AttentionPolicy): supports_prefill = True supports_decode = True - def select_blocks(self, available_blocks, ctx): - return available_blocks[:5] # Just first 5 blocks + def compute_prefill(self, q, k, v, layer_id, softmax_scale): + # Custom attention computation + ... """ from nanovllm.config import SparsePolicyType -from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext +from nanovllm.kvcache.sparse.policy import AttentionPolicy, SparsePolicy, PolicyContext from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager from nanovllm.kvcache.sparse.minference import MInferencePolicy from nanovllm.kvcache.sparse.xattn import XAttentionPolicy -def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy: +def create_attention_policy(policy_type: SparsePolicyType, **kwargs) -> AttentionPolicy: """ - Create a sparse policy instance from an enum type. + Create an attention policy instance from an enum type. - The returned policy is not yet initialized. Call policy.initialize() - or let the framework call it during KV cache allocation. + All attention (including full attention) goes through a policy in layerwise + offload mode. The policy is responsible for computing prefill/decode attention. Args: - policy_type: SparsePolicyType enum value + policy_type: SparsePolicyType enum value (FULL, XATTN, MINFERENCE, QUEST) **kwargs: Policy-specific configuration options Returns: - SparsePolicy instance (not initialized) + AttentionPolicy instance Example: - policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4) - policy.initialize(num_layers=28, num_kv_heads=8, ...) + policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9) + attn_out = policy.compute_prefill(q, k, v, layer_id, softmax_scale) """ if policy_type == SparsePolicyType.FULL: return FullAttentionPolicy() @@ -75,21 +82,32 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic keep_sink=kwargs.get("keep_sink", False), keep_recent=kwargs.get("keep_recent", False), norm=kwargs.get("norm", 1.0), + use_bsa=kwargs.get("use_bsa", True), ) else: raise ValueError(f"Unknown policy type: {policy_type}") +# Backward compatibility alias +create_sparse_policy = create_attention_policy + + __all__ = [ + # New interface + "AttentionPolicy", + "create_attention_policy", + # Backward compatibility "SparsePolicy", + "create_sparse_policy", + # Common types "PolicyContext", "SparsePolicyType", + # Policy implementations "FullAttentionPolicy", "QuestPolicy", "QuestConfig", "BlockMetadataManager", "MInferencePolicy", "XAttentionPolicy", - "create_sparse_policy", ] diff --git a/nanovllm/kvcache/sparse/full_policy.py b/nanovllm/kvcache/sparse/full_policy.py index a17f085..504163d 100644 --- a/nanovllm/kvcache/sparse/full_policy.py +++ b/nanovllm/kvcache/sparse/full_policy.py @@ -1,20 +1,21 @@ """ -Full attention policy - loads all blocks (no sparsity). +Full attention policy - standard FlashAttention without sparsity. This serves as a baseline and default policy when sparse attention is not needed. """ -from typing import List -from .policy import SparsePolicy, PolicyContext +from typing import Optional +import torch +from .policy import AttentionPolicy -class FullAttentionPolicy(SparsePolicy): +class FullAttentionPolicy(AttentionPolicy): """ - Full attention policy that loads all available blocks. + Full attention policy using FlashAttention (no sparsity). - This is the default behavior with no sparsity - all previous - KV cache blocks are loaded for each query chunk. + This is the default behavior with standard causal attention. + All tokens attend to all previous tokens. Use this as: - A baseline for comparing sparse policies @@ -25,15 +26,55 @@ class FullAttentionPolicy(SparsePolicy): # Full attention supports both prefill and decode supports_prefill = True supports_decode = True - requires_block_selection = False # Load all blocks, no selective loading - def select_blocks( + def estimate( self, - available_blocks: List[int], - ctx: PolicyContext, - ) -> List[int]: - """Return all blocks - no sparsity.""" - return available_blocks + q: torch.Tensor, + k: torch.Tensor, + layer_id: int, + ) -> Optional[torch.Tensor]: + """ + Full attention - no sparse mask needed. + + Returns None to indicate full attention should be used. + """ + return None + + def compute_prefill( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + softmax_scale: float, + ) -> torch.Tensor: + """ + Compute full causal attention using FlashAttention. + + Args: + q: Query tensor [seq_len, num_heads, head_dim] + k: Key tensor [seq_len, num_kv_heads, head_dim] + v: Value tensor [seq_len, num_kv_heads, head_dim] + layer_id: Transformer layer index + softmax_scale: Softmax scaling factor (1/sqrt(head_dim)) + + Returns: + Attention output [seq_len, num_heads, head_dim] + """ + from flash_attn.flash_attn_interface import flash_attn_varlen_func + + seq_len = q.shape[0] + cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device) + + return flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=seq_len, + max_seqlen_k=seq_len, + softmax_scale=softmax_scale, + causal=True, + ) def __repr__(self) -> str: return "FullAttentionPolicy()" diff --git a/nanovllm/kvcache/sparse/minference.py b/nanovllm/kvcache/sparse/minference.py index 2f45202..5430bf9 100644 --- a/nanovllm/kvcache/sparse/minference.py +++ b/nanovllm/kvcache/sparse/minference.py @@ -10,10 +10,10 @@ from typing import List, Tuple, Optional import torch import torch.nn.functional as F -from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext +from nanovllm.kvcache.sparse.policy import AttentionPolicy, PolicyContext -class MInferencePolicy(SparsePolicy): +class MInferencePolicy(AttentionPolicy): """ MInference sparse prefill policy using vertical + slash pattern. @@ -347,6 +347,33 @@ class MInferencePolicy(SparsePolicy): return o + def compute_prefill( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + softmax_scale: float, + ) -> torch.Tensor: + """ + Compute MInference sparse prefill attention. + + This is the new unified interface for attention policies. + Delegates to sparse_prefill_attention (ignores softmax_scale as MInference + computes it internally from head_dim). + + Args: + q: Query tensor [seq_len, num_heads, head_dim] + k: Key tensor [seq_len, num_kv_heads, head_dim] + v: Value tensor [seq_len, num_kv_heads, head_dim] + layer_id: Transformer layer index + softmax_scale: Softmax scaling factor (unused, computed internally) + + Returns: + Attention output [seq_len, num_heads, head_dim] + """ + return self.sparse_prefill_attention(q, k, v, layer_id) + def __repr__(self) -> str: return (f"MInferencePolicy(" f"adaptive_budget={self.adaptive_budget}, " diff --git a/nanovllm/kvcache/sparse/policy.py b/nanovllm/kvcache/sparse/policy.py index 74026f7..5757113 100644 --- a/nanovllm/kvcache/sparse/policy.py +++ b/nanovllm/kvcache/sparse/policy.py @@ -1,13 +1,18 @@ """ -Base class for sparse attention policies. +Base class for attention policies in layerwise offload mode. -Sparse attention policies determine which KV cache blocks to load -from CPU for each query chunk during chunked attention computation. +AttentionPolicy defines the interface for all attention computation, +including full attention and sparse attention methods like XAttention. + +Key methods: +- estimate(): Compute sparse attention mask (optional, returns None for full attention) +- compute_prefill(): Compute prefill attention +- compute_decode(): Compute decode attention (default implementation provided) """ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List, Optional, Any +from typing import List, Optional, Tuple import torch # Import SparsePolicyType from config to avoid circular imports @@ -17,10 +22,10 @@ from nanovllm.config import SparsePolicyType @dataclass class PolicyContext: """ - Context passed to sparse policy for block selection. + Context passed to attention policy for block selection. - This dataclass contains all information needed by a sparse policy - to decide which blocks to load for the current query chunk. + This dataclass contains all information needed by an attention policy + for sparse estimation and attention computation. """ query_chunk_idx: int @@ -49,40 +54,41 @@ class PolicyContext: """Total KV sequence length so far (for reference).""" -class SparsePolicy(ABC): +class AttentionPolicy(ABC): """ - Abstract base class for sparse attention policies. + Base class for attention policies in layerwise offload mode. - Subclass this and implement select_blocks() to create custom - sparse attention patterns. The policy receives context about - the current query chunk and returns which KV blocks to load. + All attention computation goes through a policy, including both + full attention and sparse attention methods. + + The policy interface is designed for layerwise offload where: + - The entire KV cache for a layer is on GPU during computation + - No need for block loading from CPU during attention + - estimate() returns a sparse mask (or None for full attention) + - compute_prefill()/compute_decode() perform the actual attention Attributes: supports_prefill: Whether this policy can be used for prefill phase. supports_decode: Whether this policy can be used for decode phase. Example: - class MySparsePolicy(SparsePolicy): - supports_prefill = False # decode-only policy + class MyPolicy(AttentionPolicy): + supports_prefill = True supports_decode = True - def select_blocks(self, available_blocks, ctx): - # Load first block and last 2 blocks - if len(available_blocks) <= 3: - return available_blocks - return [available_blocks[0]] + available_blocks[-2:] + def estimate(self, q, k, layer_id): + # Return sparse mask or None + return None + + def compute_prefill(self, q, k, v, layer_id, softmax_scale): + # Compute attention + return flash_attn_varlen_func(q, k, v, ...) """ # Compatibility flags - override in subclasses supports_prefill: bool = True supports_decode: bool = True - # Whether this policy requires selective block loading during decode - # If True: OffloadEngine will call select_blocks() before loading KV from CPU - # If False: OffloadEngine will load all blocks (select_blocks ignored for load) - # Example: MInference=False (only affects attention), Quest=True (affects load) - requires_block_selection: bool = False - def initialize( self, num_layers: int, @@ -96,7 +102,7 @@ class SparsePolicy(ABC): Initialize policy resources. Called by the framework after KV cache is allocated. Override this - to create metadata structures (e.g., BlockMetadataManager for Quest). + to create metadata structures or pre-allocate buffers. Default implementation does nothing. Args: @@ -109,76 +115,98 @@ class SparsePolicy(ABC): """ pass - @abstractmethod - def select_blocks( + def estimate( self, - available_blocks: List[int], - ctx: PolicyContext, - ) -> List[int]: + q: torch.Tensor, + k: torch.Tensor, + layer_id: int, + ) -> Optional[torch.Tensor]: """ - Select which KV blocks to load for the current query chunk. + Estimate sparse attention mask. - This is the core method that defines the sparse attention pattern. - The returned blocks will be loaded from CPU to GPU for attention - computation against the current query chunk. + For sparse policies (e.g., XAttention), computes block-level importance + and returns a boolean mask indicating which blocks to attend. + For full attention policy, returns None. + + This corresponds to xattn_estimate() in COMPASS. Args: - available_blocks: List of CPU block IDs that contain KV cache - from previous chunks. These are ordered by - their position in the sequence. - ctx: PolicyContext with information about the current query - chunk, layer, phase (prefill/decode), etc. + q: Query tensor [seq_len, num_heads, head_dim] + k: Key tensor [seq_len, num_kv_heads, head_dim] + layer_id: Transformer layer index Returns: - List of block IDs to load (must be a subset of available_blocks). - The order may affect performance (sequential access is faster). - Returning [] means no previous blocks will be loaded. + sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask, + or None for full attention """ - pass + return None - def on_prefill_offload( + @abstractmethod + def compute_prefill( self, - cpu_block_id: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, layer_id: int, - k_cache: torch.Tensor, - num_valid_tokens: int, - ) -> None: + softmax_scale: float, + ) -> torch.Tensor: """ - Hook called when a block is offloaded during prefill phase. + Compute prefill attention. - Called BEFORE GPU→CPU copy, while k_cache is still on GPU. - Override this to collect metadata about blocks (e.g., min/max keys - for Quest-style selection). Default implementation does nothing. + The entire KV cache for this layer is on GPU. Compute attention + between Q and K/V, optionally using sparse mask from estimate(). Args: - cpu_block_id: The CPU block ID that will be written + q: Query tensor [seq_len, num_heads, head_dim] + k: Key tensor [seq_len, num_kv_heads, head_dim] + v: Value tensor [seq_len, num_kv_heads, head_dim] layer_id: Transformer layer index - k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU) - num_valid_tokens: Number of valid tokens in this block + softmax_scale: Softmax scaling factor (1/sqrt(head_dim)) + + Returns: + Attention output [seq_len, num_heads, head_dim] """ pass - def on_decode_offload( + def compute_decode( self, - cpu_block_id: int, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, layer_id: int, - k_cache: torch.Tensor, - num_valid_tokens: int, - ) -> None: + softmax_scale: float, + ) -> torch.Tensor: """ - Hook called when a block is offloaded during decode phase. + Compute decode attention. - Called BEFORE GPU→CPU copy, while k_cache is still on GPU. - Override this to update metadata about blocks. Default implementation - does nothing. + KV is provided from ring buffer, containing prefill tokens + decoded tokens. + Default implementation uses FlashAttention. Args: - cpu_block_id: The CPU block ID that will be written + q: Query tensor [1, num_heads, head_dim] + k: Key tensor [context_len+1, num_kv_heads, head_dim] + v: Value tensor [context_len+1, num_kv_heads, head_dim] layer_id: Transformer layer index - k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU) - num_valid_tokens: Number of valid tokens in this block + softmax_scale: Softmax scaling factor + + Returns: + Attention output [1, num_heads, head_dim] """ - pass + from flash_attn.flash_attn_interface import flash_attn_varlen_func + + context_len = k.shape[0] + cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device) + cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device) + + return flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=1, + max_seqlen_k=context_len, + softmax_scale=softmax_scale, + causal=False, + ) def reset(self) -> None: """ @@ -189,32 +217,9 @@ class SparsePolicy(ABC): """ pass - def sparse_prefill_attention( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - layer_id: int, - ) -> torch.Tensor: - """ - Compute sparse attention for prefill phase. - - This method is called when supports_prefill=True and the policy - is used for GPU-only sparse prefill (no CPU offload). - - Args: - q: Query tensor [seq_len, num_heads, head_dim] - k: Key tensor [seq_len, num_kv_heads, head_dim] - v: Value tensor [seq_len, num_kv_heads, head_dim] - layer_id: Current transformer layer index - - Returns: - Attention output [seq_len, num_heads, head_dim] - """ - raise NotImplementedError( - f"{self.__class__.__name__} does not implement sparse_prefill_attention. " - "Set supports_prefill=False or implement this method." - ) - def __repr__(self) -> str: return f"{self.__class__.__name__}()" + + +# Backward compatibility alias +SparsePolicy = AttentionPolicy diff --git a/nanovllm/kvcache/sparse/quest.py b/nanovllm/kvcache/sparse/quest.py index 71c9063..4709a55 100644 --- a/nanovllm/kvcache/sparse/quest.py +++ b/nanovllm/kvcache/sparse/quest.py @@ -11,7 +11,7 @@ import logging import torch from dataclasses import dataclass from typing import List, Tuple, Optional -from .policy import SparsePolicy, PolicyContext +from .policy import AttentionPolicy, PolicyContext logger = logging.getLogger(__name__) @@ -137,7 +137,7 @@ class QuestConfig: """Always include this many recent blocks (last N blocks), in addition to Top-K.""" -class QuestPolicy(SparsePolicy): +class QuestPolicy(AttentionPolicy): """ Quest-style Top-K block selection using min/max key bounds. @@ -317,6 +317,25 @@ class QuestPolicy(SparsePolicy): if self.metadata is not None: self.metadata.reset() + def compute_prefill( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + layer_id: int, + softmax_scale: float, + ) -> torch.Tensor: + """ + Quest does not support prefill - raises error. + + Quest is a decode-only policy for selective block loading. + For prefill, use FullAttentionPolicy or XAttentionPolicy. + """ + raise NotImplementedError( + "QuestPolicy does not support prefill. " + "Use FullAttentionPolicy or XAttentionPolicy for prefill." + ) + def __repr__(self) -> str: return ( f"QuestPolicy(topk={self.config.topk_blocks}, " diff --git a/nanovllm/kvcache/sparse/xattn.py b/nanovllm/kvcache/sparse/xattn.py index 48ead2f..d7681c4 100644 --- a/nanovllm/kvcache/sparse/xattn.py +++ b/nanovllm/kvcache/sparse/xattn.py @@ -4,48 +4,56 @@ XAttention sparse attention policy for nano-vllm. Implements the XAttention algorithm from COMPASS, using chunked estimation and block sparse attention for efficient long-context inference. +Architecture: + XAttention = Estimate (Triton) + Compute (BSA) + - Estimate: xattn_estimate() computes block-level importance scores + - Compute: block_sparse_attn_func() executes sparse attention + Reference: COMPASS/compass/src/Xattention.py """ import math -from typing import List, Optional +from typing import Optional import torch import torch.nn.functional as F -from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext -from nanovllm.kvcache.sparse.kernels import ( - flat_group_gemm_fuse_reshape, - softmax_fuse_block_sum, -) -from nanovllm.kvcache.sparse.utils import find_blocks_chunked +from nanovllm.kvcache.sparse.policy import AttentionPolicy + +# BSA block size is fixed at 128 (hardcoded in block_sparse_attn) +BSA_BLOCK_SIZE = 128 -class XAttentionPolicy(SparsePolicy): +class XAttentionPolicy(AttentionPolicy): """ XAttention sparse prefill policy using chunked estimation + block sparse attention. This policy estimates sparse attention patterns by: - 1. Chunked QK computation using Triton kernels + 1. Chunked QK computation using Triton kernels (via nanovllm.ops.xattn) 2. Block-wise softmax with importance scores 3. Block selection based on threshold - 4. Block sparse attention computation + 4. Block sparse attention computation using MIT-HAN-LAB BSA library + + The key method is estimate() which calls xattn_estimate() from nanovllm.ops + to compute the sparse attention mask. Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.) + BSA library: https://github.com/mit-han-lab/Block-Sparse-Attention """ supports_prefill = True - supports_decode = False # XAttention is prefill-only - requires_block_selection = False # Only affects attention computation + supports_decode = True # Uses default FlashAttention for decode def __init__( self, stride: int = 8, threshold: float = 0.9, - chunk_size: Optional[int] = None, + block_size: int = 128, + chunk_size: int = 16384, use_triton: bool = True, keep_sink: bool = False, keep_recent: bool = False, norm: float = 1.0, + use_bsa: bool = True, ): """ Initialize XAttention policy. @@ -53,19 +61,28 @@ class XAttentionPolicy(SparsePolicy): Args: stride: Stride for reorganizing Q/K (default: 8) threshold: Block selection threshold, 0-1 (default: 0.9) - chunk_size: Chunk size for estimation (auto if None) + block_size: Block size for sparse attention (default: 128, must match BSA) + chunk_size: Chunk size for estimation (default: 16384) use_triton: Use Triton kernels (requires SM 80+) keep_sink: Always keep first block (sink tokens) keep_recent: Always keep recent diagonal blocks norm: Normalization factor for attention scores + use_bsa: Use Block Sparse Attention library (default: True) """ self.stride = stride self.threshold = threshold + self.block_size = block_size self.chunk_size = chunk_size self.use_triton = use_triton self.keep_sink = keep_sink self.keep_recent = keep_recent self.norm = norm + self.use_bsa = use_bsa + + # BSA requires block_size = 128 + if self.use_bsa and self.block_size != BSA_BLOCK_SIZE: + print(f"XAttention: BSA requires block_size=128, adjusting from {self.block_size}") + self.block_size = BSA_BLOCK_SIZE # Check Triton availability if self.use_triton: @@ -79,379 +96,206 @@ class XAttentionPolicy(SparsePolicy): self.use_triton = False print("XAttention: Triton not available. Falling back to PyTorch.") - def select_blocks( + # Check BSA availability + if self.use_bsa: + try: + from block_sparse_attn import block_sparse_attn_func + except ImportError: + self.use_bsa = False + print("XAttention: block_sparse_attn not available. Falling back to FlashAttention.") + + def estimate( self, - available_blocks: List[int], - ctx: PolicyContext, - ) -> List[int]: + q: torch.Tensor, + k: torch.Tensor, + layer_id: int, + ) -> Optional[torch.Tensor]: """ - Select blocks for decode phase. + Estimate sparse attention mask using XAttention algorithm. - XAttention is prefill-only, so this method is only used as a fallback. - Returns all available blocks by default. + Calls xattn_estimate() from nanovllm.ops.xattn to compute block-level + importance scores and generate a sparse boolean mask. + + Args: + q: Query tensor [seq_len, num_heads, head_dim] + k: Key tensor [seq_len, num_kv_heads, head_dim] + layer_id: Transformer layer index + + Returns: + sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask, + or None if estimation fails (fallback to full attention) """ - # XAttention is prefill-only, but we need to implement this abstract method - # Since requires_block_selection=False, this won't be called for loading - return available_blocks + try: + from nanovllm.ops.xattn import xattn_estimate - def sparse_prefill_attention( + seq_len, num_heads, head_dim = q.shape + num_kv_heads = k.shape[1] + + # Convert to [batch, heads, seq, dim] format expected by xattn_estimate + q_bhsd = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim] + k_bhsd = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, seq_len, head_dim] + + # Handle GQA: expand k to match q heads for estimation + if num_kv_heads != num_heads: + # GQA: expand k by repeating + repeat_factor = num_heads // num_kv_heads + k_bhsd = k_bhsd.repeat(1, repeat_factor, 1, 1) + + # Call xattn_estimate + attn_sums, sparse_mask = xattn_estimate( + q_bhsd, k_bhsd, + block_size=self.block_size, + stride=self.stride, + norm=self.norm, + threshold=self.threshold, + chunk_size=self.chunk_size, + use_triton=self.use_triton, + causal=True, + keep_sink=self.keep_sink, + keep_recent=self.keep_recent, + ) + + return sparse_mask + + except Exception as e: + # If estimation fails, return None to use full attention + print(f"XAttention estimate failed: {e}, falling back to full attention") + return None + + def compute_prefill( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer_id: int, + softmax_scale: float, ) -> torch.Tensor: """ - Compute XAttention sparse attention for prefill. + Compute XAttention sparse prefill attention. + + Flow: + 1. Call estimate() to get sparse mask + 2. If mask is None or BSA unavailable, use full FlashAttention + 3. Otherwise, use block_sparse_attn_func with mask Args: q: Query tensor [seq_len, num_heads, head_dim] k: Key tensor [seq_len, num_kv_heads, head_dim] v: Value tensor [seq_len, num_kv_heads, head_dim] - layer_id: Current transformer layer index + layer_id: Transformer layer index + softmax_scale: Softmax scaling factor Returns: Attention output [seq_len, num_heads, head_dim] """ - seq_len = q.shape[0] - num_heads = q.shape[1] - head_dim = q.shape[2] - num_kv_heads = k.shape[1] + # If BSA is disabled, use full attention directly (skip estimation) + if not self.use_bsa: + return self._full_attention(q, k, v, softmax_scale) - # Use FlashAttention directly for CPU offload mode - # FlashAttention supports GQA natively - try: - from flash_attn.flash_attn_interface import flash_attn_varlen_func + # Step 1: Estimate sparse mask + sparse_mask = self.estimate(q, k, layer_id) - cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device) + # Step 2: Compute attention + if sparse_mask is None: + # Estimation failed, fallback to full FlashAttention + return self._full_attention(q, k, v, softmax_scale) - attn_output = flash_attn_varlen_func( - q, k, v, - cu_seqlens_q=cu_seqlens, - cu_seqlens_k=cu_seqlens, - max_seqlen_q=seq_len, - max_seqlen_k=seq_len, - softmax_scale=1.0 / math.sqrt(head_dim), - causal=True, - ) + # Use block sparse attention with mask + return self._block_sparse_attention(q, k, v, sparse_mask, softmax_scale) - return attn_output - - except Exception as e: - # Fallback: PyTorch SDPA (supports GQA natively) - print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA") - attn_output = F.scaled_dot_product_attention( - q, k, v, - attn_mask=None, - is_causal=True, - scale=1.0 / math.sqrt(head_dim) - ) - return attn_output - - def _xattn_offload_prefill( + def _block_sparse_attention( self, - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - causal: bool = True, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + sparse_mask: torch.Tensor, + softmax_scale: float, ) -> torch.Tensor: """ - Simplified XAttention prefill for CPU offload mode. - - Uses FlashAttention with full context since chunked estimation - with full key_states requires special handling. - """ - batch_size, num_heads, q_len, head_dim = query_states.shape - _, _, k_len, _ = key_states.shape - - # Use FlashAttention with full context - # In offload mode, keys are already on CPU and loaded as needed - try: - from flash_attn.flash_attn_interface import flash_attn_varlen_func - - # Convert to [seq, heads, dim] format - q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim] - k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim] - v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim] - - cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device) - cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device) - - attn_output = flash_attn_varlen_func( - q, k, v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=q_len, - max_seqlen_k=k_len, - softmax_scale=1.0 / math.sqrt(head_dim), - causal=causal, - ) - - # Convert back to [batch, seq, heads, dim] - attn_output = attn_output.unsqueeze(0).transpose(1, 2) # [1, q_len, num_heads, head_dim] - - return attn_output - - except Exception as e: - # Final fallback: PyTorch SDPA - print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA") - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): - attn_output = F.scaled_dot_product_attention( - query_states, key_states, value_states, - attn_mask=None, - is_causal=causal, - scale=1.0 / math.sqrt(head_dim) - ) - return attn_output - - def _xattn_prefill( - self, - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - stride: int, - norm: float, - threshold: float, - block_size: int = 128, - use_triton: bool = True, - causal: bool = True, - chunk_size: Optional[int] = None, - keep_sink: bool = False, - keep_recent: bool = False, - ) -> torch.Tensor: - """ - XAttention prefill implementation. + Compute block sparse attention using MIT-HAN-LAB BSA library. Args: - query_states: [batch, num_heads, q_len, head_dim] - key_states: [batch, num_heads, k_len, head_dim] - value_states: [batch, num_heads, k_len, head_dim] - ... other params + q: Query tensor [seq_len, num_heads, head_dim] + k: Key tensor [seq_len, num_kv_heads, head_dim] + v: Value tensor [seq_len, num_kv_heads, head_dim] + sparse_mask: Block mask [batch, num_heads, q_blocks, k_blocks] + softmax_scale: Softmax scaling factor Returns: - Attention output [batch, q_len, num_heads, head_dim] + Attention output [seq_len, num_heads, head_dim] """ - batch_size, num_heads, k_len, head_dim = key_states.shape - _, _, q_len, _ = query_states.shape + from block_sparse_attn import block_sparse_attn_func - # Auto-compute chunk_size if not specified - if chunk_size is None: - chunk_size = int( - max( - min( - max(2048, 1 << (k_len - 1).bit_length()), - 128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()), - ), - 2048, - ) - ) + seq_len, num_heads, head_dim = q.shape + num_kv_heads = k.shape[1] - # Phase 1: Estimate sparse pattern - attn_sums, approx_simple_mask = self._xattn_estimate( - query_states, - key_states, - block_size=block_size, - stride=stride, - norm=norm, - threshold=threshold, - chunk_size=chunk_size, - use_triton=use_triton, - causal=causal, - keep_sink=keep_sink, - keep_recent=keep_recent, - ) + # Handle GQA: expand K/V to match Q heads + if num_kv_heads != num_heads: + repeat_factor = num_heads // num_kv_heads + k = k.repeat_interleave(repeat_factor, dim=1) + v = v.repeat_interleave(repeat_factor, dim=1) - # Phase 2: Block sparse attention - # For now, use FlashAttention as fallback since block_sparse_attn_func may not be available - attn_output = self._block_sparse_attention_fallback( - query_states, key_states, value_states, - approx_simple_mask, block_size, q_len, k_len + # Cumulative sequence lengths (batch=1) + cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device) + cu_seqlens_k = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device) + + # Head mask type: 1 for all heads using block sparse + head_mask_type = torch.ones(num_heads, dtype=torch.int32, device=q.device) + + # Trim sparse_mask to actual block counts + q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE + k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE + block_mask = sparse_mask[:, :, :q_blocks, :k_blocks].contiguous() + + # Call BSA + attn_output = block_sparse_attn_func( + q, k, v, + cu_seqlens_q, cu_seqlens_k, + head_mask_type, + None, # streaming_info (left_mask) + block_mask, + seq_len, seq_len, + p_dropout=0.0, + deterministic=True, + softmax_scale=softmax_scale, + is_causal=True, ) return attn_output - def _xattn_estimate( + def _full_attention( self, - query_states: torch.Tensor, - key_states: torch.Tensor, - block_size: int, - stride: int, - norm: float = 1, - softmax: bool = True, - threshold: float = 0.9, - chunk_size: int = 16384, - use_triton: bool = True, - causal: bool = True, - keep_sink: bool = False, - keep_recent: bool = False, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + softmax_scale: float, ) -> torch.Tensor: """ - Estimate sparse attention pattern using chunked computation. + Compute full causal attention using FlashAttention. + + Args: + q: Query tensor [seq_len, num_heads, head_dim] + k: Key tensor [seq_len, num_kv_heads, head_dim] + v: Value tensor [seq_len, num_kv_heads, head_dim] + softmax_scale: Softmax scaling factor Returns: - attn_sums: [batch, heads, q_blocks, k_blocks] - importance scores - simple_masks: [batch, heads, q_blocks, k_blocks] - boolean masks + Attention output [seq_len, num_heads, head_dim] """ - batch_size, num_kv_head, k_len, head_dim = key_states.shape - batch_size, num_q_head, q_len, head_dim = query_states.shape + from flash_attn.flash_attn_interface import flash_attn_varlen_func - k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len - q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len - k_chunk_num = (k_len + k_num_to_pad) // chunk_size - k_block_num = (k_len + k_num_to_pad) // block_size - q_chunk_num = (q_len + q_num_to_pad) // chunk_size - q_block_num = (q_len + q_num_to_pad) // block_size + seq_len = q.shape[0] + cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device) - # Pad inputs - if k_num_to_pad > 0: - pad_key_states = F.pad(key_states, (0, 0, 0, k_num_to_pad), value=0) - else: - pad_key_states = key_states - if q_num_to_pad > 0: - pad_query_states = F.pad(query_states, (0, 0, 0, q_num_to_pad), value=0) - else: - pad_query_states = query_states - - reshaped_chunk_size = chunk_size // stride - reshaped_block_size = block_size // stride - k_reshaped_seq_len = (k_len + k_num_to_pad) // stride - - attn_sum_list = [] - simple_mask_list = [] - - for chunk_idx in range(q_chunk_num): - if use_triton: - # Triton GEMM + Softmax - attn_weights_slice = flat_group_gemm_fuse_reshape( - pad_query_states[:, :, (chunk_idx * reshaped_chunk_size) * stride : (chunk_idx * reshaped_chunk_size + reshaped_chunk_size) * stride, :], - pad_key_states, - stride, - (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size, - (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size, - is_causal=causal, - ) - - attn_sum = softmax_fuse_block_sum( - attn_weights_slice, - reshaped_block_size, - min(4096, reshaped_block_size), - (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size, - (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size, - k_reshaped_seq_len - (k_num_to_pad // stride), - 1.4426950408889634 / math.sqrt(head_dim) / stride / norm, - is_causal=causal, - ) - else: - # PyTorch fallback - chunk_size_actual = reshaped_chunk_size - chunk_start = chunk_idx * chunk_size_actual - chunk_end = chunk_start + chunk_size_actual - - chunked_query = pad_query_states[:, :, chunk_start * stride:chunk_end * stride:stride, :] - attn_weights_slice = torch.matmul(chunked_query, pad_key_states.transpose(2, 3)) - attn_weights_slice = attn_weights_slice / math.sqrt(head_dim) / stride / norm - - if causal: - causal_mask = torch.zeros((batch_size, num_q_head, chunk_size_actual, chunk_size_actual * k_chunk_num), device=key_states.device) - causal_mask[:, :, :, -(k_num_to_pad // stride):] = float("-inf") - # ... more causal mask logic ... - attn_weights_slice = attn_weights_slice + causal_mask - - attn_weights_slice = F.softmax(attn_weights_slice, dim=-1, dtype=torch.float32) - attn_sum = attn_weights_slice.view(batch_size, num_q_head, chunk_size_actual // reshaped_block_size, reshaped_block_size, -1).sum(dim=-1).sum(dim=-2) - - # Find blocks based on threshold - simple_mask = find_blocks_chunked( - attn_sum, - k_block_num - q_block_num + chunk_idx * (reshaped_chunk_size // reshaped_block_size), - threshold, - None, - decoding=False, - mode="prefill", - causal=causal, - ) - - attn_sum_list.append(attn_sum) - simple_mask_list.append(simple_mask) - - attn_sums = torch.cat(attn_sum_list, dim=-2) - simple_masks = torch.cat(simple_mask_list, dim=-2) - - # Apply causal mask to block masks - if causal: - simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where( - torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=key_states.device), diagonal=0), - simple_masks[:, :, -q_block_num:, -q_block_num:], - False, - ) - - if keep_sink: - simple_masks[:, :, 0, :] = True - - if keep_recent: - eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=bool) - eye_matrix_expanded = eye_matrix.unsqueeze(0).unsqueeze(0).expand(1, num_q_head, q_block_num, q_block_num) - simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where( - eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:] - ) - - return attn_sums, simple_masks - - def _block_sparse_attention_fallback( - self, - query_states: torch.Tensor, - key_states: torch.Tensor, - value_states: torch.Tensor, - mask: torch.Tensor, - block_size: int, - q_len: int, - k_len: int, - ) -> torch.Tensor: - """ - Fallback implementation using FlashAttention. - - Since block_sparse_attn_func may not be available in all environments, - this uses standard FlashAttention with full attention. - """ - try: - from flash_attn.flash_attn_interface import flash_attn_varlen_func - - batch_size, num_heads, _, head_dim = query_states.shape - - # Convert to [seq, heads, dim] format - q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim] - k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim] - v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim] - - cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device) - cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device) - - attn_output = flash_attn_varlen_func( - q, k, v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=q_len, - max_seqlen_k=k_len, - softmax_scale=1.0 / math.sqrt(head_dim), - causal=True, - ) - - # Convert back to [batch, seq, heads, dim] - attn_output = attn_output.unsqueeze(0).transpose(1, 2) - - return attn_output - - except Exception as e: - # Final fallback: PyTorch SDPA - print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA") - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): - attn_output = F.scaled_dot_product_attention( - query_states, key_states, value_states, - attn_mask=None, - is_causal=True, - scale=1.0 / math.sqrt(query_states.shape[-1]) - ) - return attn_output + return flash_attn_varlen_func( + q, k, v, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=seq_len, + max_seqlen_k=seq_len, + softmax_scale=softmax_scale, + causal=True, + ) def reset(self) -> None: """Reset policy state (no state to reset for XAttention).""" @@ -461,4 +305,6 @@ class XAttentionPolicy(SparsePolicy): return (f"XAttentionPolicy(" f"stride={self.stride}, " f"threshold={self.threshold}, " - f"use_triton={self.use_triton})") + f"block_size={self.block_size}, " + f"use_triton={self.use_triton}, " + f"use_bsa={self.use_bsa})") diff --git a/nanovllm/layers/attention.py b/nanovllm/layers/attention.py index b9b4b8d..6e47d06 100644 --- a/nanovllm/layers/attention.py +++ b/nanovllm/layers/attention.py @@ -98,10 +98,10 @@ class Attention(nn.Module): max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q, max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k, softmax_scale=self.scale, causal=True, block_table=context.block_tables) - elif context.sparse_prefill_policy is not None: - # Sparse prefill (GPU-only) - delegate to policy - o = context.sparse_prefill_policy.sparse_prefill_attention( - q, k, v, self.layer_id + elif context.attention_policy is not None: + # Attention via policy (GPU-only) - delegate to policy + o = context.attention_policy.compute_prefill( + q, k, v, self.layer_id, softmax_scale=self.scale ) else: o = flash_attn_varlen_func(q, k, v, diff --git a/nanovllm/utils/context.py b/nanovllm/utils/context.py index 77e571f..4bb738e 100644 --- a/nanovllm/utils/context.py +++ b/nanovllm/utils/context.py @@ -14,9 +14,9 @@ class Context: context_lens: torch.Tensor | None = None block_tables: torch.Tensor | None = None - # Sparse prefill attention support (GPU-only path) - # When set, uses policy.sparse_prefill_attention() instead of FlashAttention - sparse_prefill_policy: Any = None # SparsePolicy instance with supports_prefill=True + # Attention policy support (GPU-only path) + # When set, uses policy.compute_prefill() instead of FlashAttention + attention_policy: Any = None # AttentionPolicy instance _CONTEXT = Context() @@ -35,7 +35,7 @@ def set_context( slot_mapping=None, context_lens=None, block_tables=None, - sparse_prefill_policy=None, + attention_policy=None, ): global _CONTEXT _CONTEXT = Context( @@ -47,7 +47,7 @@ def set_context( slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables, - sparse_prefill_policy=sparse_prefill_policy, + attention_policy=attention_policy, ) diff --git a/tests/test_needle.py b/tests/test_needle.py index 4e2b3c0..1c95bb2 100644 --- a/tests/test_needle.py +++ b/tests/test_needle.py @@ -32,11 +32,14 @@ def run_needle_test( enable_cpu_offload: bool = False, enable_quest: bool = False, enable_minference: bool = False, + enable_xattn: bool = False, sparse_topk: int = 8, sparse_threshold: int = 4, minference_budget: float = 0.3, minference_vertical: int = 1000, minference_slash: int = 6096, + xattn_threshold: float = 0.9, + xattn_use_bsa: bool = True, gpu_utilization: float = 0.9, enforce_eager: bool = True, verbose: bool = True, @@ -56,11 +59,14 @@ def run_needle_test( enable_cpu_offload: Enable CPU offload mode enable_quest: Enable Quest sparse attention (decode-only Top-K) enable_minference: Enable MInference sparse prefill (GPU-only) + enable_xattn: Enable XAttention sparse prefill with BSA sparse_topk: Top-K blocks for Quest sparse_threshold: Apply sparse only when blocks > threshold minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode) minference_vertical: Fixed vertical_size (only used when budget=None) minference_slash: Fixed slash_size (only used when budget=None) + xattn_threshold: XAttention block selection threshold (0-1) + xattn_use_bsa: Use Block Sparse Attention library gpu_utilization: GPU memory utilization fraction verbose: Print detailed output @@ -68,7 +74,9 @@ def run_needle_test( True if test passed, False otherwise """ # Determine sparse policy - if enable_minference: + if enable_xattn: + sparse_policy = SparsePolicyType.XATTN + elif enable_minference: sparse_policy = SparsePolicyType.MINFERENCE elif enable_quest: sparse_policy = SparsePolicyType.QUEST @@ -94,6 +102,8 @@ def run_needle_test( print(f" MInference: adaptive (budget={minference_budget})") else: print(f" MInference: fixed (vertical={minference_vertical}, slash={minference_slash})") + if enable_xattn: + print(f" XAttention: threshold={xattn_threshold}, use_bsa={xattn_use_bsa}") print(f"{'='*60}\n") # 1. Initialize LLM @@ -111,7 +121,7 @@ def run_needle_test( llm_kwargs["sparse_threshold_blocks"] = sparse_threshold # Set sparse policy (can be used with or without offload) - if enable_minference or enable_quest: + if enable_minference or enable_quest or enable_xattn: llm_kwargs["sparse_policy"] = sparse_policy # MInference params (works with both GPU-only and offload mode) @@ -120,6 +130,11 @@ def run_needle_test( llm_kwargs["minference_vertical_size"] = minference_vertical llm_kwargs["minference_slash_size"] = minference_slash + # XAttention params + if enable_xattn: + llm_kwargs["xattn_threshold"] = xattn_threshold + llm_kwargs["xattn_use_bsa"] = xattn_use_bsa + llm = LLM(model_path, **llm_kwargs) # 2. Generate needle prompt @@ -224,6 +239,11 @@ if __name__ == "__main__": action="store_true", help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)" ) + parser.add_argument( + "--enable-xattn", + action="store_true", + help="Enable XAttention sparse prefill with Block Sparse Attention" + ) parser.add_argument( "--sparse-topk", type=int, @@ -254,6 +274,17 @@ if __name__ == "__main__": default=6096, help="Fixed slash_size (only used when budget=0)" ) + parser.add_argument( + "--xattn-threshold", + type=float, + default=0.9, + help="XAttention block selection threshold (0-1, higher=more blocks)" + ) + parser.add_argument( + "--xattn-no-bsa", + action="store_true", + help="Disable Block Sparse Attention (use FlashAttention fallback)" + ) parser.add_argument( "--gpu-utilization", type=float, @@ -291,11 +322,14 @@ if __name__ == "__main__": enable_cpu_offload=args.enable_offload, enable_quest=args.enable_quest, enable_minference=args.enable_minference, + enable_xattn=args.enable_xattn, sparse_topk=args.sparse_topk, sparse_threshold=args.sparse_threshold, minference_budget=minference_budget, minference_vertical=args.minference_vertical, minference_slash=args.minference_slash, + xattn_threshold=args.xattn_threshold, + xattn_use_bsa=not args.xattn_no_bsa, gpu_utilization=args.gpu_utilization, enforce_eager=enforce_eager, verbose=True,