Files
nano-vllm/nanovllm/kvcache/sparse/utils.py
Zijie Tian ac1ccbceaa feat: add XAttention sparse policy integration
Integrate COMPASS XAttention algorithm into nano-vllm's CPU offload
execution path. Uses FlashAttention with native GQA support for
offload mode.

New files:
- nanovllm/kvcache/sparse/utils.py: find_blocks_chunked() utility
- nanovllm/kvcache/sparse/kernels.py: Triton kernels for XAttention
- nanovllm/kvcache/sparse/xattn.py: XAttentionPolicy implementation

Modified:
- nanovllm/config.py: Add XATTN configuration parameters
- nanovllm/engine/model_runner.py: Support XATTN policy
- nanovllm/kvcache/sparse/__init__.py: Register XAttentionPolicy
- tests/test_ruler.py: Add --sparse-policy parameter

Test results (32k ruler):
- NIAH tasks: 12/12 (100%)
- QA/Recall tasks: 11/15 (73%)
- Overall: 23/27 (85%)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-14 10:04:46 +08:00

157 lines
6.4 KiB
Python

"""
Utility functions for sparse attention policies.
Copied from COMPASS/compass/src/utils.py for XAttention integration.
"""
import torch
def find_blocks_chunked(
    input_tensor, current_index, threshold, num_to_choose, decoding: bool, mode: str = "both", causal=True
):
    """
    Select which attention blocks each query chunk should attend to.

    For every (batch, head, chunk) row, greedily keeps the highest-weight
    blocks until their cumulative attention mass reaches ``threshold`` of the
    row's total. Under ``causal`` selection the first block and each chunk's
    own (diagonal) block are always kept, and blocks past the current
    position are always dropped.

    Parameters:
    - input_tensor (torch.Tensor): Per-block attention-weight estimates of
      shape (batch_size, head_num, chunk_num, block_num).
    - current_index (int): Block offset of this chunk group within the full
      sequence (chunk row ``i`` corresponds to block ``current_index + i``).
    - threshold (float, torch.Tensor or None): Required fraction of the total
      attention mass. Exactly one of ``threshold`` / ``num_to_choose`` may be
      given. A tensor threshold is broadcast per (batch, head).
    - num_to_choose (int or None): Fixed per-row block budget. Not
      implemented; passing it raises NotImplementedError.
    - decoding (bool): True when called during decode, False during prefill.
    - mode (str): 'both' applies selection in both phases; 'prefill' makes
      the decode-phase call return a full (all-True) mask; 'decode' makes
      the prefill-phase call return a dense (optionally causal) mask.
    - causal (bool): If True, forbids attention to blocks after the current
      position.

    Returns:
    - torch.Tensor: Boolean mask of shape
      (batch_size, head_num, chunk_num, block_num); True marks blocks to keep.

    Raises:
    - NotImplementedError: when ``threshold`` is None (the ``num_to_choose``
      selection path is not implemented).
    """
    # Exactly one selection criterion may be provided.
    assert threshold is None or num_to_choose is None
    batch_size, head_num, chunk_num, block_num = input_tensor.shape
    if mode == "prefill" and decoding:
        # Sparsity restricted to prefill: decode attends to everything.
        return torch.ones_like(input_tensor, dtype=torch.bool)
    if mode == "decode" and not decoding:
        # Sparsity restricted to decode: prefill gets a dense mask instead
        # of a weight-based one.
        mask = torch.ones_like(input_tensor, dtype=torch.bool)
        if causal:
            mask[:, :, :, current_index : current_index + chunk_num] = torch.tril(
                torch.ones(1, head_num, chunk_num, chunk_num, device=input_tensor.device)
            )
            mask[:, :, current_index + chunk_num :, :] = 0
            # NOTE(review): the tensors below are sliced along dim 2 (chunks)
            # but concatenated along dim -1 (blocks), and the `mask` built
            # just above is discarded. This looks shape-inconsistent for
            # general inputs; kept byte-identical to the upstream COMPASS
            # source -- verify this branch against its actual callers.
            return torch.cat(
                [
                    torch.ones_like(input_tensor, dtype=torch.bool)[:, :, 0 : current_index + 1],
                    torch.zeros_like(input_tensor, dtype=torch.bool)[:, :, current_index + 1 :],
                ],
                dim=-1,
            )
        else:
            return mask
    # Promote to float64 so the sums / cumulative sums below are stable.
    input_tensor = input_tensor.to(float)
    if threshold is not None:
        total_sum = input_tensor.sum(dim=-1, keepdim=True)
        if isinstance(threshold, torch.Tensor):
            # Per-(batch, head) threshold, broadcast over chunks and blocks.
            threshold = threshold.to(float)
            required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze(
                -1
            ).expand((batch_size, head_num, chunk_num, 1)).to(input_tensor.device)
        else:
            required_sum = total_sum * threshold
        if causal:
            # Force-keep block 0 and each chunk's own diagonal block.
            mask = torch.zeros_like(input_tensor, dtype=torch.bool)
            mask[:, :, :, 0] = 1
            mask[:, :, :, current_index : current_index + chunk_num] = (
                torch.eye(chunk_num, device=mask.device)
                .unsqueeze(0)
                .unsqueeze(0)
                .expand(1, head_num, chunk_num, chunk_num)
            )
            # Rank the remaining (non-forced) blocks by weight.
            other_values = input_tensor.masked_fill(mask, 0)
            sorted_values, _ = torch.sort(
                other_values, dim=-1, descending=True
            )
            sorted_values = sorted_values.to(input_tensor.device)
            # Prepend the combined weight of the forced blocks so the greedy
            # cumulative sum starts from what is already selected.
            sorted_values = torch.cat(
                [
                    torch.zeros(
                        (batch_size, head_num, chunk_num, 1), device=input_tensor.device
                    ),
                    torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True),
                    sorted_values[:, :, :, :-2],
                ],
                dim=-1,
            )
            # The large additive bias pins forced blocks to the front of the
            # ordering, keeping `index` aligned with `sorted_values` above.
            _, index = torch.sort(
                torch.where(mask, 100000 * (1 + input_tensor), input_tensor),
                dim=-1,
                descending=True
            )
            # Exclusive prefix sum: mass accumulated *before* each candidate.
            cumulative_sum_without_self = torch.cat(
                [
                    torch.zeros(
                        (batch_size, head_num, chunk_num, 1), device=input_tensor.device
                    ),
                    sorted_values[:, :, :, 0:-1],
                ],
                dim=-1,
            ).cumsum(dim=-1)
            # Keep candidates until the running total meets the requirement;
            # discarded slots are redirected to block 0 (already True).
            index_mask = cumulative_sum_without_self < required_sum
            index = torch.where(index_mask, index, 0)
            mask = mask.view(batch_size, head_num * chunk_num, block_num)
            index = index.view(batch_size, head_num * chunk_num, block_num)
            mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True
            mask = mask.view(batch_size, head_num, chunk_num, block_num)
        else:
            # Non-causal: plain greedy top-weight selection, no forced
            # diagonal. Discarded slots still redirect to index 0, so block 0
            # ends up True whenever any slot is discarded.
            mask = torch.zeros_like(input_tensor, dtype=torch.bool)
            sorted_values, index = torch.sort(
                input_tensor, dim=-1, descending=True
            )
            sorted_values = sorted_values.to(input_tensor.device)
            cumulative_sum_without_self = torch.cat(
                [
                    torch.zeros(
                        (batch_size, head_num, chunk_num, 1), device=input_tensor.device
                    ),
                    sorted_values[:, :, :, 0:-1],
                ],
                dim=-1,
            ).cumsum(dim=-1)
            index_mask = cumulative_sum_without_self < required_sum
            index = torch.where(index_mask, index, 0)
            mask = mask.view(batch_size, head_num * chunk_num, block_num)
            index = index.view(batch_size, head_num * chunk_num, block_num)
            mask[
                :,
                torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1),
                index,
            ] = True
            mask = mask.view(batch_size, head_num, chunk_num, block_num)
    else:
        raise NotImplementedError("block num chunk prefill not implemented")
    if causal:
        # Enforce causality directly: blocks past the current position are
        # never attended. (Replaces an assert inside a bare `except:` that
        # was a no-op under `python -O` and swallowed unrelated errors.)
        mask[:, :, :, current_index + chunk_num :] = False
        if decoding:
            # Sanity check: sink block and the current (last) block are kept.
            assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all()
        else:
            # Sanity check: sink block and diagonal blocks are kept.
            lambda_mask = torch.zeros_like(input_tensor, dtype=bool, device=input_tensor.device)
            lambda_mask[:, :, :, 0] = 1
            lambda_mask[:, :, :, current_index:current_index+chunk_num] = torch.eye(
                chunk_num, device=lambda_mask.device
            ).unsqueeze(0).unsqueeze(0).expand(1, head_num, chunk_num, chunk_num)
            assert(torch.where(lambda_mask, mask, True).all())
    return mask