Integrate COMPASS XAttention algorithm into nano-vllm's CPU offload execution path. Uses FlashAttention with native GQA support for offload mode. New files: - nanovllm/kvcache/sparse/utils.py: find_blocks_chunked() utility - nanovllm/kvcache/sparse/kernels.py: Triton kernels for XAttention - nanovllm/kvcache/sparse/xattn.py: XAttentionPolicy implementation Modified: - nanovllm/config.py: Add XATTN configuration parameters - nanovllm/engine/model_runner.py: Support XATTN policy - nanovllm/kvcache/sparse/__init__.py: Register XAttentionPolicy - tests/test_ruler.py: Add --sparse-policy parameter Test results (32k ruler): - NIAH tasks: 12/12 (100%) - QA/Recall tasks: 11/15 (73%) - Overall: 23/27 (85%) Co-Authored-By: Claude <noreply@anthropic.com>
321 lines
10 KiB
Python
321 lines
10 KiB
Python
"""
|
|
Triton kernels for XAttention sparse attention.
|
|
|
|
Copied and adapted from COMPASS/compass/src/kernels.py
|
|
for XAttention integration in nano-vllm.
|
|
|
|
Requirements:
|
|
- Triton >= 2.1.0
|
|
- CUDA compute capability SM 80+ (RTX 3090, A100, H100, etc.)
|
|
"""
|
|
|
|
import torch
|
|
import math
|
|
import triton
|
|
import triton.language as tl
|
|
|
|
|
|
@triton.jit
def softmax_fuse_block_sum_kernel_causal(
    In,
    Out,
    scale,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_stride_0,
    output_stride_1,
    output_stride_2,
    real_q_len,
    k_len,
    chunk_start,
    chunk_end,
    segment_size: tl.constexpr,
    block_size: tl.constexpr,
):
    """Causal two-pass softmax over key segments, fused with a per-tile sum.

    Grid axes (0, 1, 2) are (query block, head, batch). Each program owns
    ``block_size`` consecutive query rows and walks the key axis in segments
    of ``segment_size`` columns. The absolute position of a query row is
    ``chunk_start + block_id * block_size + row``.

    Pass 1 accumulates a running row maximum and normaliser (base-2
    exponentials) over every segment up to and including the one that holds
    the causal diagonal. Pass 2 reloads those segments, normalises them,
    zeroes rows whose position is >= ``real_q_len`` and reduces every
    ``block_size x block_size`` tile to a single value, producing
    ``segment_size // block_size`` outputs per segment. Segments lying
    entirely after the diagonal are written as zeros.

    Notes:
        * ``chunk_end`` is accepted but never read.
        * Exponentials use ``exp2``; ``scale`` presumably already carries the
          log2(e) factor -- TODO confirm against the caller.
        * No load or store is masked, and the diagonal segment is processed
          unconditionally, so the launcher must guarantee that shapes divide
          evenly and that the diagonal segment lies inside the input.
        * The last dimension of ``In`` and ``Out`` is addressed with unit
          stride.
    """
    q_block = tl.program_id(0)
    head = tl.program_id(1)
    batch = tl.program_id(2)

    # Absolute query positions of this block and key offsets within a segment.
    q_pos = tl.arange(0, block_size) + chunk_start + q_block * block_size
    k_pos = tl.arange(0, segment_size)

    total_segments = k_len // segment_size
    # Index of the segment containing the last query row's diagonal element.
    diag_segment = (chunk_start + (q_block + 1) * block_size - 1) // segment_size

    # Running softmax statistics. The initial 1.0 in row_sum is wiped out by
    # the first update because the rescale factor is exp2(-inf - max) == 0
    # (for finite scores).
    row_max = tl.zeros([block_size], dtype=tl.float32) - float("inf")
    row_sum = tl.zeros([block_size], dtype=tl.float32) + 1.0

    # Offsets are added to the pointer left-to-right, one term at a time.
    src = In + batch * input_stride_0 + head * input_stride_1 + q_block * block_size * input_stride_2
    src = src + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2

    dst = Out + batch * output_stride_0 + head * output_stride_1 + q_block * output_stride_2
    dst = dst + tl.arange(0, segment_size // block_size)

    # ---- Pass 1a: segments fully visible to every row of this block ----
    for seg in range(0, diag_segment):
        scores = tl.load(src + seg * segment_size).to(tl.float32) * scale
        seg_max = tl.max(scores, 1)
        new_max = tl.maximum(row_max, seg_max)
        rescale = tl.math.exp2(row_max - new_max)

        scores = scores - new_max[:, None]
        seg_sum = tl.sum(tl.math.exp2(scores), 1)
        row_sum = row_sum * rescale + seg_sum

        row_max = new_max

    # ---- Pass 1b: the diagonal segment, with the causal mask applied ----
    # Written as a single-trip loop so that its temporaries stay loop-local.
    for seg in range(diag_segment, diag_segment + 1):
        scores = tl.load(src + seg * segment_size).to(tl.float32) * scale
        visible = q_pos[:, None] >= (k_pos[None, :] + seg * segment_size)
        # Future keys get a large negative score so exp2 underflows to zero.
        scores = tl.where(visible, scores, -1.0e6)
        seg_max = tl.max(scores, 1)
        new_max = tl.maximum(row_max, seg_max)
        rescale = tl.math.exp2(row_max - new_max)

        scores = scores - new_max[:, None]
        seg_sum = tl.sum(tl.math.exp2(scores), 1)
        row_sum = row_sum * rescale + seg_sum

        row_max = new_max

    inv_row_sum = 1.0 / row_sum

    # Rows at or beyond real_q_len contribute nothing to the tile sums.
    row_is_real = q_pos[:, None] < real_q_len

    # ---- Pass 2a: normalise and reduce the fully visible segments ----
    for seg in range(0, diag_segment):
        scores = tl.load(src + seg * segment_size).to(tl.float32) * scale
        probs = tl.exp2(scores - row_max[:, None]) * inv_row_sum[:, None]
        probs = tl.where(row_is_real, probs, 0)
        tiles = tl.reshape(probs, (block_size, segment_size // block_size, block_size))
        tile_mass = tl.sum(tiles, 2)
        tile_mass = tl.sum(tile_mass, 0)
        tl.store(dst + seg * segment_size // block_size, tile_mass.to(Out.type.element_ty))

    # ---- Pass 2b: normalise and reduce the diagonal segment ----
    for seg in range(diag_segment, diag_segment + 1):
        scores = tl.load(src + seg * segment_size).to(tl.float32) * scale
        visible = q_pos[:, None] >= (k_pos[None, :] + seg * segment_size)
        scores = tl.where(visible, scores, -1.0e6)
        probs = tl.exp2(scores - row_max[:, None]) * inv_row_sum[:, None]
        probs = tl.where(row_is_real, probs, 0)
        tiles = tl.reshape(probs, (block_size, segment_size // block_size, block_size))
        tile_mass = tl.sum(tiles, 2)
        tile_mass = tl.sum(tile_mass, 0)
        tl.store(dst + seg * segment_size // block_size, tile_mass.to(Out.type.element_ty))

    # ---- Segments strictly after the diagonal carry no attention mass ----
    for seg in range(diag_segment + 1, total_segments):
        empty = tl.zeros([segment_size // block_size], dtype=tl.float32)
        tl.store(dst + seg * segment_size // block_size, empty.to(Out.type.element_ty))
|
|
|
|
|
|
@triton.jit
def softmax_fuse_block_sum_kernel_non_causal(
    In,
    Out,
    scale,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_stride_0,
    output_stride_1,
    output_stride_2,
    real_q_len,
    k_len,
    chunk_start,
    chunk_end,
    segment_size: tl.constexpr,
    block_size: tl.constexpr,
):
    """Two-pass softmax over all key segments, fused with a per-tile sum.

    Grid axes (0, 1, 2) are (query block, head, batch). Each program owns
    ``block_size`` consecutive query rows and walks the whole key axis in
    segments of ``segment_size`` columns; no causal mask is applied.

    Pass 1 accumulates a running row maximum and normaliser (base-2
    exponentials) over every segment. Pass 2 reloads each segment, normalises
    it, zeroes rows whose absolute position
    (``chunk_start + block_id * block_size + row``) is >= ``real_q_len`` and
    reduces every ``block_size x block_size`` tile to one value, producing
    ``segment_size // block_size`` outputs per segment.

    Notes:
        * ``chunk_end`` is accepted but never read.
        * Exponentials use ``exp2``; ``scale`` presumably already carries the
          log2(e) factor -- TODO confirm against the caller.
        * No load or store is masked, so the launcher must guarantee that the
          shapes divide evenly.
        * The last dimension of ``In`` and ``Out`` is addressed with unit
          stride.
    """
    q_block = tl.program_id(0)
    head = tl.program_id(1)
    batch = tl.program_id(2)

    # Absolute query positions covered by this program.
    q_pos = tl.arange(0, block_size) + chunk_start + q_block * block_size

    total_segments = k_len // segment_size

    # Running softmax statistics. The initial 1.0 in row_sum is wiped out by
    # the first update because the rescale factor is exp2(-inf - max) == 0
    # (for finite scores).
    row_max = tl.zeros([block_size], dtype=tl.float32) - float("inf")
    row_sum = tl.zeros([block_size], dtype=tl.float32) + 1.0

    # Offsets are added to the pointer left-to-right, one term at a time.
    src = In + batch * input_stride_0 + head * input_stride_1 + q_block * block_size * input_stride_2
    src = src + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2

    dst = Out + batch * output_stride_0 + head * output_stride_1 + q_block * output_stride_2
    dst = dst + tl.arange(0, segment_size // block_size)

    # ---- Pass 1: running max / normaliser over every segment ----
    for seg in range(0, total_segments):
        scores = tl.load(src + seg * segment_size).to(tl.float32) * scale
        seg_max = tl.max(scores, 1)
        new_max = tl.maximum(row_max, seg_max)
        rescale = tl.math.exp2(row_max - new_max)

        scores = scores - new_max[:, None]
        seg_sum = tl.sum(tl.math.exp2(scores), 1)
        row_sum = row_sum * rescale + seg_sum

        row_max = new_max

    inv_row_sum = 1.0 / row_sum

    # Rows at or beyond real_q_len contribute nothing to the tile sums.
    row_is_real = q_pos[:, None] < real_q_len

    # ---- Pass 2: normalise each segment and reduce it tile by tile ----
    for seg in range(0, total_segments):
        scores = tl.load(src + seg * segment_size).to(tl.float32) * scale
        probs = tl.exp2(scores - row_max[:, None]) * inv_row_sum[:, None]
        probs = tl.where(row_is_real, probs, 0)
        tiles = tl.reshape(probs, (block_size, segment_size // block_size, block_size))
        tile_mass = tl.sum(tiles, 2)
        tile_mass = tl.sum(tile_mass, 0)
        tl.store(dst + seg * segment_size // block_size, tile_mass.to(Out.type.element_ty))
|
|
|
|
|
|
@triton.jit
def flat_group_gemm_fuse_reshape_kernel(
    Q,
    K,
    Out,
    stride_qz,
    stride_qh,
    stride_qn,
    stride_kz,
    stride_kh,
    stride_kn,
    stride_oz,
    stride_oh,
    stride_on,
    chunk_start,
    chunk_end,
    H: tl.constexpr,
    STRIDE: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    is_causal: tl.constexpr,
):
    """Strided Q.K^T accumulation producing a down-sampled score matrix.

    Grid axes (0, 1, 2) are (output row tile, output column tile,
    batch * H + head). Each program writes one ``BLOCK_M x BLOCK_N`` tile of
    ``Out``. Output element ``(m, n)`` is the sum over ``s`` in
    ``[0, STRIDE)`` of ``dot(Q[m * STRIDE + STRIDE - 1 - s], K[n * STRIDE + s])``:
    queries inside a group are visited in reverse order, keys in forward
    order.

    When ``is_causal`` is set, a tile is skipped -- nothing is stored -- if
    ``chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N``. Skipped
    tiles keep whatever the output buffer already contained.

    Notes:
        * ``chunk_end`` is accepted but never read.
        * The head dimension of ``Q``/``K`` and the last dimension of ``Out``
          are addressed with unit stride.
        * Loads and stores are unmasked, so the launcher must guarantee that
          the sequence lengths are multiples of ``STRIDE * BLOCK_M`` /
          ``STRIDE * BLOCK_N``.
    """
    # Program ids are widened to int64 before they enter the offset arithmetic.
    tile_m = tl.program_id(0).to(tl.int64)
    tile_n = tl.program_id(1).to(tl.int64)
    batch = tl.program_id(2).to(tl.int64) // H
    head = tl.program_id(2).to(tl.int64) % H

    # Kept as two nested tests: the outer one is a compile-time constant, the
    # inner one depends on runtime values.
    if is_causal:
        if chunk_start + (tile_m + 1) * BLOCK_M <= tile_n * BLOCK_N:
            return

    q_ptrs = Q + batch * stride_qz + head * stride_qh + tile_m * BLOCK_M * STRIDE * stride_qn
    k_ptrs = K + batch * stride_kz + head * stride_kh + tile_n * BLOCK_N * STRIDE * stride_kn

    # Q tile is [BLOCK_M, HEAD_DIM], starting at the LAST row of each group of
    # STRIDE rows; K tile is [HEAD_DIM, BLOCK_N] (already transposed for the
    # dot), starting at the FIRST row of each group.
    q_ptrs = q_ptrs + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE) + tl.arange(0, HEAD_DIM)[None, :] + stride_qn * (STRIDE - 1)
    k_ptrs = k_ptrs + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE) + tl.arange(0, HEAD_DIM)[:, None]

    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    # Step backwards through the query group and forwards through the key group.
    for step in range(STRIDE):
        q_tile = tl.load(q_ptrs - step * stride_qn)
        k_tile = tl.load(k_ptrs + step * stride_kn)
        acc += tl.dot(q_tile, k_tile)

    out_ptrs = Out + batch * stride_oz + head * stride_oh + tile_m * BLOCK_M * stride_on + tile_n * BLOCK_N
    out_ptrs = out_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :]

    tl.store(out_ptrs, acc.to(Out.type.element_ty))
|
|
|
|
|
|
def softmax_fuse_block_sum(attn_weights_slice, reshaped_block_size, segment_size, chunk_start, chunk_end, real_q_len, scale, is_causal=True):
    """Launch the fused softmax + block-sum Triton kernel.

    Args:
        attn_weights_slice: 4-D tensor ``(batch, heads, q_len, k_len)`` of raw
            scores whose last dimension has unit stride.
        reshaped_block_size: Edge length of the square tiles that are summed;
            must divide both ``q_len`` and ``segment_size``.
        segment_size: Number of key columns processed per kernel iteration;
            must divide ``k_len``.
        chunk_start: Offset added to query row indices inside the kernel.
        chunk_end: Forwarded to the kernel (which does not read it).
        real_q_len: Query rows at positions >= this value are excluded from
            the sums.
        scale: Multiplier applied to the scores before the base-2 softmax.
        is_causal: Selects the causal kernel when truthy, otherwise the
            non-causal one.

    Returns:
        Tensor of shape ``(batch, heads, q_len // reshaped_block_size,
        k_len // reshaped_block_size)`` with the dtype and device of
        ``attn_weights_slice``.

    Raises:
        ValueError: If the sizes are not positive, do not divide evenly, or
            the last dimension of the input is not contiguous. These checks
            are explicit raises (not ``assert``) because the kernels perform
            unmasked loads and must never run on invalid shapes, even under
            ``python -O``.
    """
    batch_size, num_heads, q_len, k_len = attn_weights_slice.shape

    if reshaped_block_size <= 0 or segment_size <= 0:
        raise ValueError(
            f"reshaped_block_size ({reshaped_block_size}) and segment_size "
            f"({segment_size}) must be positive"
        )
    if q_len % reshaped_block_size != 0:
        raise ValueError(
            f"q_len ({q_len}) must be a multiple of reshaped_block_size "
            f"({reshaped_block_size})"
        )
    if k_len % segment_size != 0:
        raise ValueError(
            f"k_len ({k_len}) must be a multiple of segment_size ({segment_size})"
        )
    if segment_size % reshaped_block_size != 0:
        raise ValueError(
            f"segment_size ({segment_size}) must be a multiple of "
            f"reshaped_block_size ({reshaped_block_size})"
        )
    if attn_weights_slice.stride(-1) != 1:
        raise ValueError("attn_weights_slice must have unit stride in its last dimension")

    output = torch.empty(
        (batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size),
        dtype=attn_weights_slice.dtype,
        device=attn_weights_slice.device,
    )

    # One program per (query block, head, batch).
    grid = (q_len // reshaped_block_size, num_heads, batch_size)

    # Both kernels share one signature, so pick the kernel and launch once.
    kernel = (
        softmax_fuse_block_sum_kernel_causal
        if is_causal
        else softmax_fuse_block_sum_kernel_non_causal
    )

    # Launch on the device that owns the tensor rather than whichever CUDA
    # device happens to be current.
    with torch.cuda.device(attn_weights_slice.device):
        kernel[grid](
            attn_weights_slice,
            output,
            scale,
            attn_weights_slice.stride(0),
            attn_weights_slice.stride(1),
            attn_weights_slice.stride(2),
            output.stride(0),
            output.stride(1),
            output.stride(2),
            real_q_len,
            k_len,
            chunk_start,
            chunk_end,
            segment_size,
            reshaped_block_size,
        )

    return output
|
|
|
|
|
|
def flat_group_gemm_fuse_reshape(query_states, key_states, stride, chunk_start, chunk_end, is_causal=True):
    """Launch the strided Q.K^T Triton kernel that yields down-sampled scores.

    Args:
        query_states: 4-D CUDA tensor ``(batch, heads, q_len, head_dim)`` with
            unit stride in the last dimension.
        key_states: 4-D tensor ``(batch, heads, kv_len, head_dim)`` on the same
            device, with unit stride in the last dimension.
        stride: Number of consecutive rows folded into one output row/column.
        chunk_start: Offset used by the kernel's causal tile-skip test.
        chunk_end: Forwarded to the kernel (which does not read it).
        is_causal: When true, the kernel skips tiles that lie entirely after
            the causal boundary.

    Returns:
        Tensor of shape ``(batch, heads, q_len // stride, kv_len // stride)``
        with the dtype and device of ``query_states``. The buffer comes from
        ``torch.empty``; tiles the kernel skips in causal mode are therefore
        left uninitialised.

    Raises:
        ValueError: If the key shape does not match the query shape, either
            tensor is not contiguous in its last dimension, the tensors are
            on different devices, or the sequence lengths are not multiples
            of ``stride * BLOCK``. These are explicit raises (not ``assert``)
            because the kernel performs unmasked loads and stores.
    """
    batch_size, num_heads, q_len, head_dim = query_states.shape
    kv_len = key_states.shape[2]

    key_dims = (key_states.shape[0], key_states.shape[1], key_states.shape[3])
    if key_dims != (batch_size, num_heads, head_dim):
        raise ValueError(
            "key_states must match query_states in batch, heads and head_dim: "
            f"got {tuple(key_states.shape)} vs {tuple(query_states.shape)}"
        )
    # The kernel indexes head_dim with an implicit stride of 1.
    if query_states.stride(-1) != 1 or key_states.stride(-1) != 1:
        raise ValueError(
            "query_states and key_states must have unit stride in their last dimension"
        )
    if key_states.device != query_states.device:
        raise ValueError(
            f"query_states ({query_states.device}) and key_states "
            f"({key_states.device}) must be on the same device"
        )

    device = query_states.device

    # Tile size heuristic: total device memory is used as a proxy for the GPU
    # class (e.g. a 24 GB RTX 3090 gets the smaller tiles). Query the device
    # that actually holds the tensors, not the current one.
    props = torch.cuda.get_device_properties(device)
    if props.total_memory < 30 * 1024**3:
        BLOCK_M = 64
        BLOCK_N = 64
    else:
        BLOCK_M = 128
        BLOCK_N = 128

    if q_len % (stride * BLOCK_M) != 0:
        raise ValueError(
            f"q_len ({q_len}) must be a multiple of stride * BLOCK_M "
            f"({stride} * {BLOCK_M})"
        )
    if kv_len % (stride * BLOCK_N) != 0:
        raise ValueError(
            f"kv_len ({kv_len}) must be a multiple of stride * BLOCK_N "
            f"({stride} * {BLOCK_N})"
        )

    output = torch.empty(
        (batch_size, num_heads, q_len // stride, kv_len // stride),
        dtype=query_states.dtype,
        device=device,
    )

    # One program per (output row tile, output column tile, batch * head).
    grid = (q_len // stride // BLOCK_M, kv_len // stride // BLOCK_N, batch_size * num_heads)

    # Launch on the device that owns the tensors.
    with torch.cuda.device(device):
        flat_group_gemm_fuse_reshape_kernel[grid](
            query_states,
            key_states,
            output,
            query_states.stride(0),
            query_states.stride(1),
            query_states.stride(2),
            key_states.stride(0),
            key_states.stride(1),
            key_states.stride(2),
            output.stride(0),
            output.stride(1),
            output.stride(2),
            chunk_start,
            chunk_end,
            num_heads,
            stride,
            head_dim,
            BLOCK_M,
            BLOCK_N,
            is_causal,
        )

    return output
|