feat: add XAttention sparse policy integration
Integrate COMPASS XAttention algorithm into nano-vllm's CPU offload execution path. Uses FlashAttention with native GQA support for offload mode. New files: - nanovllm/kvcache/sparse/utils.py: find_blocks_chunked() utility - nanovllm/kvcache/sparse/kernels.py: Triton kernels for XAttention - nanovllm/kvcache/sparse/xattn.py: XAttentionPolicy implementation Modified: - nanovllm/config.py: Add XATTN configuration parameters - nanovllm/engine/model_runner.py: Support XATTN policy - nanovllm/kvcache/sparse/__init__.py: Register XAttentionPolicy - tests/test_ruler.py: Add --sparse-policy parameter Test results (32k ruler): - NIAH tasks: 12/12 (100%) - QA/Recall tasks: 11/15 (73%) - Overall: 23/27 (85%) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -10,6 +10,7 @@ class SparsePolicyType(Enum):
|
||||
FULL = auto() # No sparse attention (load all blocks)
|
||||
QUEST = auto() # Query-aware Top-K block selection (decode only)
|
||||
MINFERENCE = auto() # MInference vertical + slash sparse prefill (GPU-only)
|
||||
XATTN = auto() # XAttention chunked estimation + block-sparse attention
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -53,6 +54,15 @@ class Config:
|
||||
minference_num_sink_tokens: int = 30 # Sink tokens to always keep
|
||||
minference_num_recent_diags: int = 100 # Recent diagonals to always keep
|
||||
|
||||
# XAttention configuration (used when sparse_policy == XATTN)
|
||||
xattn_stride: int = 8 # Stride for reorganizing Q/K
|
||||
xattn_threshold: float = 0.9 # Block selection threshold (0-1)
|
||||
xattn_chunk_size: int = 16384 # Chunk size for estimation (auto if None)
|
||||
xattn_use_triton: bool = True # Use Triton kernels (requires SM 80+)
|
||||
xattn_keep_sink: bool = False # Always keep first block (sink tokens)
|
||||
xattn_keep_recent: bool = False # Always keep recent diagonal blocks
|
||||
xattn_norm: float = 1.0 # Normalization factor for attention scores
|
||||
|
||||
def __post_init__(self):
|
||||
assert os.path.isdir(self.model)
|
||||
assert self.kvcache_block_size % 256 == 0
|
||||
|
||||
@@ -178,19 +178,34 @@ class ModelRunner:
|
||||
# Create KV cache manager using factory
|
||||
self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
|
||||
|
||||
# Create sparse prefill policy for GPU-only path
|
||||
# This is separate from CPU offload sparse policy (which uses select_blocks)
|
||||
# Create sparse prefill policy
|
||||
# This is used for both GPU-only and CPU offload modes when policy supports prefill
|
||||
self.sparse_prefill_policy = None
|
||||
if not config.enable_cpu_offload and config.sparse_policy != SparsePolicyType.FULL:
|
||||
if config.sparse_policy != SparsePolicyType.FULL:
|
||||
from nanovllm.kvcache.sparse import create_sparse_policy
|
||||
policy = create_sparse_policy(
|
||||
config.sparse_policy,
|
||||
vertical_size=config.minference_vertical_size,
|
||||
slash_size=config.minference_slash_size,
|
||||
adaptive_budget=config.minference_adaptive_budget,
|
||||
num_sink_tokens=config.minference_num_sink_tokens,
|
||||
num_recent_diags=config.minference_num_recent_diags,
|
||||
)
|
||||
|
||||
# Get policy-specific parameters based on type
|
||||
if config.sparse_policy == SparsePolicyType.XATTN:
|
||||
policy_kwargs = {
|
||||
"stride": config.xattn_stride,
|
||||
"threshold": config.xattn_threshold,
|
||||
"chunk_size": config.xattn_chunk_size,
|
||||
"use_triton": config.xattn_use_triton,
|
||||
"keep_sink": config.xattn_keep_sink,
|
||||
"keep_recent": config.xattn_keep_recent,
|
||||
"norm": config.xattn_norm,
|
||||
}
|
||||
else: # MINFERENCE or others
|
||||
policy_kwargs = {
|
||||
"vertical_size": config.minference_vertical_size,
|
||||
"slash_size": config.minference_slash_size,
|
||||
"adaptive_budget": config.minference_adaptive_budget,
|
||||
"num_sink_tokens": config.minference_num_sink_tokens,
|
||||
"num_recent_diags": config.minference_num_recent_diags,
|
||||
}
|
||||
|
||||
policy = create_sparse_policy(config.sparse_policy, **policy_kwargs)
|
||||
|
||||
# Only use if policy supports sparse prefill
|
||||
if policy.supports_prefill:
|
||||
self.sparse_prefill_policy = policy
|
||||
|
||||
@@ -24,6 +24,7 @@ from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||
from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
|
||||
from nanovllm.kvcache.sparse.minference import MInferencePolicy
|
||||
from nanovllm.kvcache.sparse.xattn import XAttentionPolicy
|
||||
|
||||
|
||||
def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
|
||||
@@ -65,6 +66,17 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolic
|
||||
num_recent_diags=kwargs.get("num_recent_diags", 100),
|
||||
)
|
||||
|
||||
elif policy_type == SparsePolicyType.XATTN:
|
||||
return XAttentionPolicy(
|
||||
stride=kwargs.get("stride", 8),
|
||||
threshold=kwargs.get("threshold", 0.9),
|
||||
chunk_size=kwargs.get("chunk_size", 16384),
|
||||
use_triton=kwargs.get("use_triton", True),
|
||||
keep_sink=kwargs.get("keep_sink", False),
|
||||
keep_recent=kwargs.get("keep_recent", False),
|
||||
norm=kwargs.get("norm", 1.0),
|
||||
)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown policy type: {policy_type}")
|
||||
|
||||
@@ -78,5 +90,6 @@ __all__ = [
|
||||
"QuestConfig",
|
||||
"BlockMetadataManager",
|
||||
"MInferencePolicy",
|
||||
"XAttentionPolicy",
|
||||
"create_sparse_policy",
|
||||
]
|
||||
|
||||
320
nanovllm/kvcache/sparse/kernels.py
Normal file
320
nanovllm/kvcache/sparse/kernels.py
Normal file
@@ -0,0 +1,320 @@
|
||||
"""
|
||||
Triton kernels for XAttention sparse attention.
|
||||
|
||||
Copied and adapted from COMPASS/compass/src/kernels.py
|
||||
for XAttention integration in nano-vllm.
|
||||
|
||||
Requirements:
|
||||
- Triton >= 2.1.0
|
||||
- CUDA compute capability SM 80+ (RTX 3090, A100, H100, etc.)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import math
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
@triton.jit
|
||||
def softmax_fuse_block_sum_kernel_causal(
|
||||
In,
|
||||
Out,
|
||||
scale,
|
||||
input_stride_0,
|
||||
input_stride_1,
|
||||
input_stride_2,
|
||||
output_stride_0,
|
||||
output_stride_1,
|
||||
output_stride_2,
|
||||
real_q_len,
|
||||
k_len,
|
||||
chunk_start,
|
||||
chunk_end,
|
||||
segment_size: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
):
|
||||
block_id = tl.program_id(0)
|
||||
head_id = tl.program_id(1)
|
||||
batch_id = tl.program_id(2)
|
||||
|
||||
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
|
||||
offs_k = tl.arange(0, segment_size)
|
||||
|
||||
num_iters = k_len // segment_size
|
||||
num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1) // segment_size
|
||||
|
||||
m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0
|
||||
|
||||
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
|
||||
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
|
||||
|
||||
output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
|
||||
output_ptr = output_ptr + tl.arange(0, segment_size // block_size)
|
||||
|
||||
for iter in range(0, num_iters_before_causal):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
m_local = tl.max(X, 1)
|
||||
m_new = tl.maximum(m_i, m_local)
|
||||
alpha = tl.math.exp2(m_i - m_new)
|
||||
|
||||
X = X - m_new[:, None]
|
||||
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||
l_i = l_i * alpha + l_local
|
||||
|
||||
m_i = m_new
|
||||
|
||||
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size)
|
||||
X = tl.where(mask, X, -1.0e6)
|
||||
m_local = tl.max(X, 1)
|
||||
m_new = tl.maximum(m_i, m_local)
|
||||
alpha = tl.math.exp2(m_i - m_new)
|
||||
|
||||
X = X - m_new[:, None]
|
||||
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||
l_i = l_i * alpha + l_local
|
||||
|
||||
m_i = m_new
|
||||
|
||||
l_i_inv = 1.0 / l_i
|
||||
|
||||
sum_mask = offs_q[:, None] < real_q_len
|
||||
|
||||
for iter in range(0, num_iters_before_causal):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
||||
X = tl.where(sum_mask, X, 0)
|
||||
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||
X = tl.sum(X, 2)
|
||||
X = tl.sum(X, 0)
|
||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||
|
||||
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size)
|
||||
X = tl.where(mask, X, -1.0e6)
|
||||
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
||||
X = tl.where(sum_mask, X, 0)
|
||||
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||
X = tl.sum(X, 2)
|
||||
X = tl.sum(X, 0)
|
||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||
|
||||
for iter in range(num_iters_before_causal + 1, num_iters):
|
||||
X = tl.zeros([segment_size // block_size], dtype=tl.float32)
|
||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def softmax_fuse_block_sum_kernel_non_causal(
|
||||
In,
|
||||
Out,
|
||||
scale,
|
||||
input_stride_0,
|
||||
input_stride_1,
|
||||
input_stride_2,
|
||||
output_stride_0,
|
||||
output_stride_1,
|
||||
output_stride_2,
|
||||
real_q_len,
|
||||
k_len,
|
||||
chunk_start,
|
||||
chunk_end,
|
||||
segment_size: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
):
|
||||
block_id = tl.program_id(0)
|
||||
head_id = tl.program_id(1)
|
||||
batch_id = tl.program_id(2)
|
||||
|
||||
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
|
||||
offs_k = tl.arange(0, segment_size)
|
||||
|
||||
num_iters = k_len // segment_size
|
||||
|
||||
m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0
|
||||
|
||||
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
|
||||
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
|
||||
|
||||
output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
|
||||
output_ptr = output_ptr + tl.arange(0, segment_size // block_size)
|
||||
|
||||
for iter in range(0, num_iters):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
m_local = tl.max(X, 1)
|
||||
m_new = tl.maximum(m_i, m_local)
|
||||
alpha = tl.math.exp2(m_i - m_new)
|
||||
|
||||
X = X - m_new[:, None]
|
||||
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||
l_i = l_i * alpha + l_local
|
||||
|
||||
m_i = m_new
|
||||
|
||||
l_i_inv = 1.0 / l_i
|
||||
|
||||
sum_mask = offs_q[:, None] < real_q_len
|
||||
|
||||
for iter in range(0, num_iters):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
||||
X = tl.where(sum_mask, X, 0)
|
||||
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||
X = tl.sum(X, 2)
|
||||
X = tl.sum(X, 0)
|
||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def flat_group_gemm_fuse_reshape_kernel(Q, K, Out,
|
||||
stride_qz, stride_qh, stride_qn,
|
||||
stride_kz, stride_kh, stride_kn,
|
||||
stride_oz, stride_oh, stride_on,
|
||||
chunk_start, chunk_end,
|
||||
H: tl.constexpr,
|
||||
STRIDE: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
is_causal: tl.constexpr,
|
||||
):
|
||||
block_m = tl.program_id(0).to(tl.int64)
|
||||
block_n = tl.program_id(1).to(tl.int64)
|
||||
batch_id = tl.program_id(2).to(tl.int64) // H
|
||||
head_id = tl.program_id(2).to(tl.int64) % H
|
||||
|
||||
if is_causal:
|
||||
if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N:
|
||||
return
|
||||
|
||||
Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * STRIDE * stride_qn
|
||||
K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * STRIDE * stride_kn
|
||||
|
||||
Q_ptrs = Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE) + tl.arange(0, HEAD_DIM)[None, :] + stride_qn * (STRIDE - 1)
|
||||
K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE) + tl.arange(0, HEAD_DIM)[:, None]
|
||||
|
||||
o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
|
||||
for iter in range(STRIDE):
|
||||
q = tl.load(Q_ptrs - iter * stride_qn)
|
||||
k = tl.load(K_ptrs + iter * stride_kn)
|
||||
o += tl.dot(q, k)
|
||||
|
||||
O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N
|
||||
O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :]
|
||||
|
||||
tl.store(O_ptrs, o.to(Out.type.element_ty))
|
||||
|
||||
|
||||
def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size, segment_size, chunk_start, chunk_end, real_q_len, scale, is_causal=True):
|
||||
"""Wrapper for Triton softmax-fuse-block-sum kernel."""
|
||||
batch_size, num_heads, q_len, k_len = attn_weights_slice.shape
|
||||
assert q_len % reshaped_block_size == 0
|
||||
assert k_len % segment_size == 0
|
||||
assert segment_size % reshaped_block_size == 0
|
||||
assert attn_weights_slice.stride(-1) == 1
|
||||
|
||||
output = torch.empty(
|
||||
(batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size),
|
||||
dtype=attn_weights_slice.dtype,
|
||||
device=attn_weights_slice.device
|
||||
)
|
||||
|
||||
grid = (q_len // reshaped_block_size, num_heads, batch_size)
|
||||
|
||||
if is_causal:
|
||||
softmax_fuse_block_sum_kernel_causal[grid](
|
||||
attn_weights_slice,
|
||||
output,
|
||||
scale,
|
||||
attn_weights_slice.stride(0),
|
||||
attn_weights_slice.stride(1),
|
||||
attn_weights_slice.stride(2),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
output.stride(2),
|
||||
real_q_len,
|
||||
k_len,
|
||||
chunk_start,
|
||||
chunk_end,
|
||||
segment_size,
|
||||
reshaped_block_size,
|
||||
)
|
||||
else:
|
||||
softmax_fuse_block_sum_kernel_non_causal[grid](
|
||||
attn_weights_slice,
|
||||
output,
|
||||
scale,
|
||||
attn_weights_slice.stride(0),
|
||||
attn_weights_slice.stride(1),
|
||||
attn_weights_slice.stride(2),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
output.stride(2),
|
||||
real_q_len,
|
||||
k_len,
|
||||
chunk_start,
|
||||
chunk_end,
|
||||
segment_size,
|
||||
reshaped_block_size,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def flat_group_gemm_fuse_reshape(query_states, key_states, stride, chunk_start, chunk_end, is_causal=True):
|
||||
"""Wrapper for Triton flat-group-gemm-fuse-reshape kernel."""
|
||||
batch_size, num_heads, q_len, head_dim = query_states.shape
|
||||
kv_len = key_states.shape[2]
|
||||
|
||||
assert key_states.shape[0] == batch_size
|
||||
assert key_states.shape[1] == num_heads
|
||||
assert key_states.shape[3] == head_dim
|
||||
|
||||
output = torch.empty(
|
||||
(batch_size, num_heads, q_len // stride, kv_len // stride),
|
||||
dtype=query_states.dtype,
|
||||
device=query_states.device
|
||||
)
|
||||
|
||||
# Adjust block size based on GPU shared memory
|
||||
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||
if props.total_memory < 30 * 1024**3: # Less than 30GB (e.g., RTX 3090 24GB)
|
||||
BLOCK_M = 64
|
||||
BLOCK_N = 64
|
||||
else:
|
||||
BLOCK_M = 128
|
||||
BLOCK_N = 128
|
||||
|
||||
assert q_len % (stride * BLOCK_M) == 0
|
||||
assert kv_len % (stride * BLOCK_N) == 0
|
||||
|
||||
grid = (q_len // stride // BLOCK_M, kv_len // stride // BLOCK_N, batch_size * num_heads)
|
||||
flat_group_gemm_fuse_reshape_kernel[grid](
|
||||
query_states,
|
||||
key_states,
|
||||
output,
|
||||
query_states.stride(0),
|
||||
query_states.stride(1),
|
||||
query_states.stride(2),
|
||||
key_states.stride(0),
|
||||
key_states.stride(1),
|
||||
key_states.stride(2),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
output.stride(2),
|
||||
chunk_start,
|
||||
chunk_end,
|
||||
num_heads,
|
||||
stride,
|
||||
head_dim,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
is_causal,
|
||||
)
|
||||
|
||||
return output
|
||||
156
nanovllm/kvcache/sparse/utils.py
Normal file
156
nanovllm/kvcache/sparse/utils.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Utility functions for sparse attention policies.
|
||||
|
||||
Copied from COMPASS/compass/src/utils.py for XAttention integration.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def find_blocks_chunked(
|
||||
input_tensor, current_index, threshold, num_to_choose, decoding: bool, mode: str = "both", causal=True
|
||||
):
|
||||
"""
|
||||
Finds and selects relevant blocks of attention for transformer-based models based on a
|
||||
threshold or a predefined number of blocks.
|
||||
|
||||
Parameters:
|
||||
- input_tensor (torch.Tensor): The input tensor of shape (batch_size, head_num, chunk_num, block_num).
|
||||
- current_index (int): The current index in the sequence processing.
|
||||
- threshold (float or None): A threshold value used to determine the minimum attention weight sum.
|
||||
- num_to_choose (int or None): The number of blocks to be selected, ensuring sufficient information retrieval.
|
||||
- decoding (bool): If True, operates in decoding mode; otherwise, it's in encoding mode.
|
||||
- mode (str): Defines the processing mode, either 'both', 'prefill', or 'decode'.
|
||||
- causal (bool): If True, applies causal masking to prevent future information leakage.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: A boolean mask of shape (batch_size, head_num, chunk_num, block_num),
|
||||
indicating which blocks should be attended to.
|
||||
"""
|
||||
assert threshold is None or num_to_choose is None
|
||||
batch_size, head_num, chunk_num, block_num = input_tensor.shape
|
||||
|
||||
if mode == "prefill" and decoding:
|
||||
return torch.ones_like(input_tensor, dtype=torch.bool)
|
||||
if mode == "decode" and not decoding:
|
||||
mask = torch.ones_like(input_tensor, dtype=torch.bool)
|
||||
if causal:
|
||||
mask[:, :, :, current_index : current_index + chunk_num] = torch.tril(
|
||||
torch.ones(1, head_num, chunk_num, chunk_num, device=input_tensor.device)
|
||||
)
|
||||
mask[:, :, current_index + chunk_num :, :] = 0
|
||||
return torch.cat(
|
||||
[
|
||||
torch.ones_like(input_tensor, dtype=torch.bool)[:, :, 0 : current_index + 1],
|
||||
torch.zeros_like(input_tensor, dtype=torch.bool)[:, :, current_index + 1 :],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
else:
|
||||
return mask
|
||||
|
||||
input_tensor = input_tensor.to(float)
|
||||
|
||||
if threshold is not None:
|
||||
total_sum = input_tensor.sum(dim=-1, keepdim=True)
|
||||
if isinstance(threshold, torch.Tensor):
|
||||
threshold = threshold.to(float)
|
||||
required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze(
|
||||
-1
|
||||
).expand((batch_size, head_num, chunk_num, 1)).to(input_tensor.device)
|
||||
else:
|
||||
required_sum = total_sum * threshold
|
||||
|
||||
if causal:
|
||||
mask = torch.zeros_like(input_tensor, dtype=torch.bool)
|
||||
mask[:, :, :, 0] = 1
|
||||
mask[:, :, :, current_index : current_index + chunk_num] = (
|
||||
torch.eye(chunk_num, device=mask.device)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(0)
|
||||
.expand(1, head_num, chunk_num, chunk_num)
|
||||
)
|
||||
other_values = input_tensor.masked_fill(mask, 0)
|
||||
sorted_values, _ = torch.sort(
|
||||
other_values, dim=-1, descending=True
|
||||
)
|
||||
sorted_values = sorted_values.to(input_tensor.device)
|
||||
|
||||
sorted_values = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
|
||||
),
|
||||
torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True),
|
||||
sorted_values[:, :, :, :-2],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
_, index = torch.sort(
|
||||
torch.where(mask, 100000 * (1 + input_tensor), input_tensor),
|
||||
dim=-1,
|
||||
descending=True
|
||||
)
|
||||
cumulative_sum_without_self = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
|
||||
),
|
||||
sorted_values[:, :, :, 0:-1],
|
||||
],
|
||||
dim=-1,
|
||||
).cumsum(dim=-1)
|
||||
|
||||
index_mask = cumulative_sum_without_self < required_sum
|
||||
index = torch.where(index_mask, index, 0)
|
||||
mask = mask.view(batch_size, head_num * chunk_num, block_num)
|
||||
index = index.view(batch_size, head_num * chunk_num, block_num)
|
||||
mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True
|
||||
mask = mask.view(batch_size, head_num, chunk_num, block_num)
|
||||
else:
|
||||
mask = torch.zeros_like(input_tensor, dtype=torch.bool)
|
||||
sorted_values, index = torch.sort(
|
||||
input_tensor, dim=-1, descending=True
|
||||
)
|
||||
sorted_values = sorted_values.to(input_tensor.device)
|
||||
cumulative_sum_without_self = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
|
||||
),
|
||||
sorted_values[:, :, :, 0:-1],
|
||||
],
|
||||
dim=-1,
|
||||
).cumsum(dim=-1)
|
||||
index_mask = cumulative_sum_without_self < required_sum
|
||||
index = torch.where(index_mask, index, 0)
|
||||
mask = mask.view(batch_size, head_num * chunk_num, block_num)
|
||||
index = index.view(batch_size, head_num * chunk_num, block_num)
|
||||
mask[
|
||||
:,
|
||||
torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1),
|
||||
index,
|
||||
] = True
|
||||
mask = mask.view(batch_size, head_num, chunk_num, block_num)
|
||||
else:
|
||||
raise NotImplementedError("block num chunk prefill not implemented")
|
||||
|
||||
try:
|
||||
if causal:
|
||||
assert (~mask[:, :, :, current_index + chunk_num :]).all()
|
||||
except:
|
||||
mask[:, :, :, current_index + chunk_num :] = False
|
||||
|
||||
if causal:
|
||||
if decoding:
|
||||
assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all()
|
||||
else:
|
||||
lambda_mask = torch.zeros_like(input_tensor, dtype=bool, device=input_tensor.device)
|
||||
lambda_mask[:, :, :, 0] = 1
|
||||
lambda_mask[:, :, :, current_index:current_index+chunk_num] = torch.eye(
|
||||
chunk_num, device=lambda_mask.device
|
||||
).unsqueeze(0).unsqueeze(0).expand(1, head_num, chunk_num, chunk_num)
|
||||
assert(torch.where(lambda_mask, mask, True).all())
|
||||
|
||||
return mask
|
||||
464
nanovllm/kvcache/sparse/xattn.py
Normal file
464
nanovllm/kvcache/sparse/xattn.py
Normal file
@@ -0,0 +1,464 @@
|
||||
"""
|
||||
XAttention sparse attention policy for nano-vllm.
|
||||
|
||||
Implements the XAttention algorithm from COMPASS, using chunked estimation
|
||||
and block sparse attention for efficient long-context inference.
|
||||
|
||||
Reference: COMPASS/compass/src/Xattention.py
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||
from nanovllm.kvcache.sparse.kernels import (
|
||||
flat_group_gemm_fuse_reshape,
|
||||
softmax_fuse_block_sum,
|
||||
)
|
||||
from nanovllm.kvcache.sparse.utils import find_blocks_chunked
|
||||
|
||||
|
||||
class XAttentionPolicy(SparsePolicy):
|
||||
"""
|
||||
XAttention sparse prefill policy using chunked estimation + block sparse attention.
|
||||
|
||||
This policy estimates sparse attention patterns by:
|
||||
1. Chunked QK computation using Triton kernels
|
||||
2. Block-wise softmax with importance scores
|
||||
3. Block selection based on threshold
|
||||
4. Block sparse attention computation
|
||||
|
||||
Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
|
||||
"""
|
||||
|
||||
supports_prefill = True
|
||||
supports_decode = False # XAttention is prefill-only
|
||||
requires_block_selection = False # Only affects attention computation
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stride: int = 8,
|
||||
threshold: float = 0.9,
|
||||
chunk_size: Optional[int] = None,
|
||||
use_triton: bool = True,
|
||||
keep_sink: bool = False,
|
||||
keep_recent: bool = False,
|
||||
norm: float = 1.0,
|
||||
):
|
||||
"""
|
||||
Initialize XAttention policy.
|
||||
|
||||
Args:
|
||||
stride: Stride for reorganizing Q/K (default: 8)
|
||||
threshold: Block selection threshold, 0-1 (default: 0.9)
|
||||
chunk_size: Chunk size for estimation (auto if None)
|
||||
use_triton: Use Triton kernels (requires SM 80+)
|
||||
keep_sink: Always keep first block (sink tokens)
|
||||
keep_recent: Always keep recent diagonal blocks
|
||||
norm: Normalization factor for attention scores
|
||||
"""
|
||||
self.stride = stride
|
||||
self.threshold = threshold
|
||||
self.chunk_size = chunk_size
|
||||
self.use_triton = use_triton
|
||||
self.keep_sink = keep_sink
|
||||
self.keep_recent = keep_recent
|
||||
self.norm = norm
|
||||
|
||||
# Check Triton availability
|
||||
if self.use_triton:
|
||||
try:
|
||||
import triton
|
||||
props = torch.cuda.get_device_properties(torch.cuda.current_device())
|
||||
if props.major < 8:
|
||||
self.use_triton = False
|
||||
print(f"XAttention: Triton requires SM 80+, got SM {props.major}{props.minor}. Falling back to PyTorch.")
|
||||
except ImportError:
|
||||
self.use_triton = False
|
||||
print("XAttention: Triton not available. Falling back to PyTorch.")
|
||||
|
||||
def select_blocks(
|
||||
self,
|
||||
available_blocks: List[int],
|
||||
ctx: PolicyContext,
|
||||
) -> List[int]:
|
||||
"""
|
||||
Select blocks for decode phase.
|
||||
|
||||
XAttention is prefill-only, so this method is only used as a fallback.
|
||||
Returns all available blocks by default.
|
||||
"""
|
||||
# XAttention is prefill-only, but we need to implement this abstract method
|
||||
# Since requires_block_selection=False, this won't be called for loading
|
||||
return available_blocks
|
||||
|
||||
def sparse_prefill_attention(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
layer_id: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute XAttention sparse attention for prefill.
|
||||
|
||||
Args:
|
||||
q: Query tensor [seq_len, num_heads, head_dim]
|
||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||
layer_id: Current transformer layer index
|
||||
|
||||
Returns:
|
||||
Attention output [seq_len, num_heads, head_dim]
|
||||
"""
|
||||
seq_len = q.shape[0]
|
||||
num_heads = q.shape[1]
|
||||
head_dim = q.shape[2]
|
||||
num_kv_heads = k.shape[1]
|
||||
|
||||
# Use FlashAttention directly for CPU offload mode
|
||||
# FlashAttention supports GQA natively
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
|
||||
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens,
|
||||
cu_seqlens_k=cu_seqlens,
|
||||
max_seqlen_q=seq_len,
|
||||
max_seqlen_k=seq_len,
|
||||
softmax_scale=1.0 / math.sqrt(head_dim),
|
||||
causal=True,
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
except Exception as e:
|
||||
# Fallback: PyTorch SDPA (supports GQA natively)
|
||||
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=None,
|
||||
is_causal=True,
|
||||
scale=1.0 / math.sqrt(head_dim)
|
||||
)
|
||||
return attn_output
|
||||
|
||||
def _xattn_offload_prefill(
|
||||
self,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
causal: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Simplified XAttention prefill for CPU offload mode.
|
||||
|
||||
Uses FlashAttention with full context since chunked estimation
|
||||
with full key_states requires special handling.
|
||||
"""
|
||||
batch_size, num_heads, q_len, head_dim = query_states.shape
|
||||
_, _, k_len, _ = key_states.shape
|
||||
|
||||
# Use FlashAttention with full context
|
||||
# In offload mode, keys are already on CPU and loaded as needed
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
|
||||
# Convert to [seq, heads, dim] format
|
||||
q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim]
|
||||
k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||
v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||
|
||||
cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
|
||||
cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=q_len,
|
||||
max_seqlen_k=k_len,
|
||||
softmax_scale=1.0 / math.sqrt(head_dim),
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
# Convert back to [batch, seq, heads, dim]
|
||||
attn_output = attn_output.unsqueeze(0).transpose(1, 2) # [1, q_len, num_heads, head_dim]
|
||||
|
||||
return attn_output
|
||||
|
||||
except Exception as e:
|
||||
# Final fallback: PyTorch SDPA
|
||||
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query_states, key_states, value_states,
|
||||
attn_mask=None,
|
||||
is_causal=causal,
|
||||
scale=1.0 / math.sqrt(head_dim)
|
||||
)
|
||||
return attn_output
|
||||
|
||||
def _xattn_prefill(
|
||||
self,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
stride: int,
|
||||
norm: float,
|
||||
threshold: float,
|
||||
block_size: int = 128,
|
||||
use_triton: bool = True,
|
||||
causal: bool = True,
|
||||
chunk_size: Optional[int] = None,
|
||||
keep_sink: bool = False,
|
||||
keep_recent: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
XAttention prefill implementation.
|
||||
|
||||
Args:
|
||||
query_states: [batch, num_heads, q_len, head_dim]
|
||||
key_states: [batch, num_heads, k_len, head_dim]
|
||||
value_states: [batch, num_heads, k_len, head_dim]
|
||||
... other params
|
||||
|
||||
Returns:
|
||||
Attention output [batch, q_len, num_heads, head_dim]
|
||||
"""
|
||||
batch_size, num_heads, k_len, head_dim = key_states.shape
|
||||
_, _, q_len, _ = query_states.shape
|
||||
|
||||
# Auto-compute chunk_size if not specified
|
||||
if chunk_size is None:
|
||||
chunk_size = int(
|
||||
max(
|
||||
min(
|
||||
max(2048, 1 << (k_len - 1).bit_length()),
|
||||
128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()),
|
||||
),
|
||||
2048,
|
||||
)
|
||||
)
|
||||
|
||||
# Phase 1: Estimate sparse pattern
|
||||
attn_sums, approx_simple_mask = self._xattn_estimate(
|
||||
query_states,
|
||||
key_states,
|
||||
block_size=block_size,
|
||||
stride=stride,
|
||||
norm=norm,
|
||||
threshold=threshold,
|
||||
chunk_size=chunk_size,
|
||||
use_triton=use_triton,
|
||||
causal=causal,
|
||||
keep_sink=keep_sink,
|
||||
keep_recent=keep_recent,
|
||||
)
|
||||
|
||||
# Phase 2: Block sparse attention
|
||||
# For now, use FlashAttention as fallback since block_sparse_attn_func may not be available
|
||||
attn_output = self._block_sparse_attention_fallback(
|
||||
query_states, key_states, value_states,
|
||||
approx_simple_mask, block_size, q_len, k_len
|
||||
)
|
||||
|
||||
return attn_output
|
||||
|
||||
def _xattn_estimate(
|
||||
self,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
block_size: int,
|
||||
stride: int,
|
||||
norm: float = 1,
|
||||
softmax: bool = True,
|
||||
threshold: float = 0.9,
|
||||
chunk_size: int = 16384,
|
||||
use_triton: bool = True,
|
||||
causal: bool = True,
|
||||
keep_sink: bool = False,
|
||||
keep_recent: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Estimate sparse attention pattern using chunked computation.
|
||||
|
||||
Returns:
|
||||
attn_sums: [batch, heads, q_blocks, k_blocks] - importance scores
|
||||
simple_masks: [batch, heads, q_blocks, k_blocks] - boolean masks
|
||||
"""
|
||||
batch_size, num_kv_head, k_len, head_dim = key_states.shape
|
||||
batch_size, num_q_head, q_len, head_dim = query_states.shape
|
||||
|
||||
k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
|
||||
q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
|
||||
k_chunk_num = (k_len + k_num_to_pad) // chunk_size
|
||||
k_block_num = (k_len + k_num_to_pad) // block_size
|
||||
q_chunk_num = (q_len + q_num_to_pad) // chunk_size
|
||||
q_block_num = (q_len + q_num_to_pad) // block_size
|
||||
|
||||
# Pad inputs
|
||||
if k_num_to_pad > 0:
|
||||
pad_key_states = F.pad(key_states, (0, 0, 0, k_num_to_pad), value=0)
|
||||
else:
|
||||
pad_key_states = key_states
|
||||
if q_num_to_pad > 0:
|
||||
pad_query_states = F.pad(query_states, (0, 0, 0, q_num_to_pad), value=0)
|
||||
else:
|
||||
pad_query_states = query_states
|
||||
|
||||
reshaped_chunk_size = chunk_size // stride
|
||||
reshaped_block_size = block_size // stride
|
||||
k_reshaped_seq_len = (k_len + k_num_to_pad) // stride
|
||||
|
||||
attn_sum_list = []
|
||||
simple_mask_list = []
|
||||
|
||||
for chunk_idx in range(q_chunk_num):
|
||||
if use_triton:
|
||||
# Triton GEMM + Softmax
|
||||
attn_weights_slice = flat_group_gemm_fuse_reshape(
|
||||
pad_query_states[:, :, (chunk_idx * reshaped_chunk_size) * stride : (chunk_idx * reshaped_chunk_size + reshaped_chunk_size) * stride, :],
|
||||
pad_key_states,
|
||||
stride,
|
||||
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
|
||||
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
|
||||
is_causal=causal,
|
||||
)
|
||||
|
||||
attn_sum = softmax_fuse_block_sum(
|
||||
attn_weights_slice,
|
||||
reshaped_block_size,
|
||||
min(4096, reshaped_block_size),
|
||||
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
|
||||
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
|
||||
k_reshaped_seq_len - (k_num_to_pad // stride),
|
||||
1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
|
||||
is_causal=causal,
|
||||
)
|
||||
else:
|
||||
# PyTorch fallback
|
||||
chunk_size_actual = reshaped_chunk_size
|
||||
chunk_start = chunk_idx * chunk_size_actual
|
||||
chunk_end = chunk_start + chunk_size_actual
|
||||
|
||||
chunked_query = pad_query_states[:, :, chunk_start * stride:chunk_end * stride:stride, :]
|
||||
attn_weights_slice = torch.matmul(chunked_query, pad_key_states.transpose(2, 3))
|
||||
attn_weights_slice = attn_weights_slice / math.sqrt(head_dim) / stride / norm
|
||||
|
||||
if causal:
|
||||
causal_mask = torch.zeros((batch_size, num_q_head, chunk_size_actual, chunk_size_actual * k_chunk_num), device=key_states.device)
|
||||
causal_mask[:, :, :, -(k_num_to_pad // stride):] = float("-inf")
|
||||
# ... more causal mask logic ...
|
||||
attn_weights_slice = attn_weights_slice + causal_mask
|
||||
|
||||
attn_weights_slice = F.softmax(attn_weights_slice, dim=-1, dtype=torch.float32)
|
||||
attn_sum = attn_weights_slice.view(batch_size, num_q_head, chunk_size_actual // reshaped_block_size, reshaped_block_size, -1).sum(dim=-1).sum(dim=-2)
|
||||
|
||||
# Find blocks based on threshold
|
||||
simple_mask = find_blocks_chunked(
|
||||
attn_sum,
|
||||
k_block_num - q_block_num + chunk_idx * (reshaped_chunk_size // reshaped_block_size),
|
||||
threshold,
|
||||
None,
|
||||
decoding=False,
|
||||
mode="prefill",
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
attn_sum_list.append(attn_sum)
|
||||
simple_mask_list.append(simple_mask)
|
||||
|
||||
attn_sums = torch.cat(attn_sum_list, dim=-2)
|
||||
simple_masks = torch.cat(simple_mask_list, dim=-2)
|
||||
|
||||
# Apply causal mask to block masks
|
||||
if causal:
|
||||
simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
|
||||
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=key_states.device), diagonal=0),
|
||||
simple_masks[:, :, -q_block_num:, -q_block_num:],
|
||||
False,
|
||||
)
|
||||
|
||||
if keep_sink:
|
||||
simple_masks[:, :, 0, :] = True
|
||||
|
||||
if keep_recent:
|
||||
eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=bool)
|
||||
eye_matrix_expanded = eye_matrix.unsqueeze(0).unsqueeze(0).expand(1, num_q_head, q_block_num, q_block_num)
|
||||
simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
|
||||
eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:]
|
||||
)
|
||||
|
||||
return attn_sums, simple_masks
|
||||
|
||||
def _block_sparse_attention_fallback(
|
||||
self,
|
||||
query_states: torch.Tensor,
|
||||
key_states: torch.Tensor,
|
||||
value_states: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
block_size: int,
|
||||
q_len: int,
|
||||
k_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Fallback implementation using FlashAttention.
|
||||
|
||||
Since block_sparse_attn_func may not be available in all environments,
|
||||
this uses standard FlashAttention with full attention.
|
||||
"""
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
|
||||
batch_size, num_heads, _, head_dim = query_states.shape
|
||||
|
||||
# Convert to [seq, heads, dim] format
|
||||
q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim]
|
||||
k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||
v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||
|
||||
cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
|
||||
cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q, k, v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=q_len,
|
||||
max_seqlen_k=k_len,
|
||||
softmax_scale=1.0 / math.sqrt(head_dim),
|
||||
causal=True,
|
||||
)
|
||||
|
||||
# Convert back to [batch, seq, heads, dim]
|
||||
attn_output = attn_output.unsqueeze(0).transpose(1, 2)
|
||||
|
||||
return attn_output
|
||||
|
||||
except Exception as e:
|
||||
# Final fallback: PyTorch SDPA
|
||||
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
|
||||
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
|
||||
attn_output = F.scaled_dot_product_attention(
|
||||
query_states, key_states, value_states,
|
||||
attn_mask=None,
|
||||
is_causal=True,
|
||||
scale=1.0 / math.sqrt(query_states.shape[-1])
|
||||
)
|
||||
return attn_output
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset policy state (no state to reset for XAttention)."""
|
||||
pass
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (f"XAttentionPolicy("
|
||||
f"stride={self.stride}, "
|
||||
f"threshold={self.threshold}, "
|
||||
f"use_triton={self.use_triton})")
|
||||
Reference in New Issue
Block a user