Integrate COMPASS XAttention algorithm into nano-vllm's CPU offload execution path. Uses FlashAttention with native GQA support for offload mode. New files: - nanovllm/kvcache/sparse/utils.py: find_blocks_chunked() utility - nanovllm/kvcache/sparse/kernels.py: Triton kernels for XAttention - nanovllm/kvcache/sparse/xattn.py: XAttentionPolicy implementation Modified: - nanovllm/config.py: Add XATTN configuration parameters - nanovllm/engine/model_runner.py: Support XATTN policy - nanovllm/kvcache/sparse/__init__.py: Register XAttentionPolicy - tests/test_ruler.py: Add --sparse-policy parameter Test results (32k ruler): - NIAH tasks: 12/12 (100%) - QA/Recall tasks: 11/15 (73%) - Overall: 23/27 (85%) Co-Authored-By: Claude <noreply@anthropic.com>
157 lines
6.4 KiB
Python
157 lines
6.4 KiB
Python
"""
|
|
Utility functions for sparse attention policies.
|
|
|
|
Copied from COMPASS/compass/src/utils.py for XAttention integration.
|
|
"""
|
|
|
|
import torch
|
|
|
|
|
|
def find_blocks_chunked(
    input_tensor, current_index, threshold, num_to_choose, decoding: bool, mode: str = "both", causal=True
):
    """
    Finds and selects relevant blocks of attention for transformer-based models based on a
    threshold or a predefined number of blocks.

    Parameters:
    - input_tensor (torch.Tensor): The input tensor of shape (batch_size, head_num, chunk_num, block_num).
    - current_index (int): The current index in the sequence processing.
    - threshold (float or torch.Tensor or None): A threshold value used to determine the minimum
      attention weight sum. A tensor is interpreted as one threshold per head.
    - num_to_choose (int or None): The number of blocks to be selected, ensuring sufficient
      information retrieval. Mutually exclusive with ``threshold``.
    - decoding (bool): If True, operates in decoding mode; otherwise, it's in encoding mode.
    - mode (str): Defines the processing mode, either 'both', 'prefill', or 'decode'.
    - causal (bool): If True, applies causal masking to prevent future information leakage.

    Returns:
    - torch.Tensor: A boolean mask of shape (batch_size, head_num, chunk_num, block_num),
      indicating which blocks should be attended to.

    Raises:
    - NotImplementedError: If selection by ``num_to_choose`` is requested (``threshold`` is None).
    """
    # At most one of the two selection criteria may be provided.
    assert threshold is None or num_to_choose is None
    batch_size, head_num, chunk_num, block_num = input_tensor.shape

    # A policy that only sparsifies the prefill attends densely while decoding.
    if mode == "prefill" and decoding:
        return torch.ones_like(input_tensor, dtype=torch.bool)
    # A policy that only sparsifies decoding attends densely (up to causality) during prefill.
    if mode == "decode" and not decoding:
        mask = torch.ones_like(input_tensor, dtype=torch.bool)
        if causal:
            # Lower-triangular pattern over the current chunk window keeps each
            # query chunk from seeing later chunks' blocks.
            mask[:, :, :, current_index : current_index + chunk_num] = torch.tril(
                torch.ones(1, head_num, chunk_num, chunk_num, device=input_tensor.device)
            )
            mask[:, :, current_index + chunk_num :, :] = 0
            # NOTE(review): these slices index the chunk dimension (dim 2) but are
            # concatenated along dim=-1 (blocks); this only type-checks when the two
            # chunk extents happen to match — confirm intent against upstream COMPASS.
            return torch.cat(
                [
                    torch.ones_like(input_tensor, dtype=torch.bool)[:, :, 0 : current_index + 1],
                    torch.zeros_like(input_tensor, dtype=torch.bool)[:, :, current_index + 1 :],
                ],
                dim=-1,
            )
        else:
            return mask

    # float64 keeps the cumulative sums below numerically stable.
    input_tensor = input_tensor.to(float)

    if threshold is not None:
        total_sum = input_tensor.sum(dim=-1, keepdim=True)
        if isinstance(threshold, torch.Tensor):
            # Per-head thresholds: broadcast to (batch_size, head_num, chunk_num, 1).
            threshold = threshold.to(float)
            required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze(
                -1
            ).expand((batch_size, head_num, chunk_num, 1)).to(input_tensor.device)
        else:
            required_sum = total_sum * threshold

        if causal:
            # Always keep the attention-sink block (index 0) and each query chunk's
            # own diagonal block; these are forced into the mask up front.
            mask = torch.zeros_like(input_tensor, dtype=torch.bool)
            mask[:, :, :, 0] = 1
            mask[:, :, :, current_index : current_index + chunk_num] = (
                torch.eye(chunk_num, device=mask.device)
                .unsqueeze(0)
                .unsqueeze(0)
                .expand(1, head_num, chunk_num, chunk_num)
            )
            # Rank the remaining blocks by weight, with the forced blocks zeroed out
            # so they do not compete for the budget.
            other_values = input_tensor.masked_fill(mask, 0)
            sorted_values, _ = torch.sort(
                other_values, dim=-1, descending=True
            )
            sorted_values = sorted_values.to(input_tensor.device)

            # Prepend a zero plus the combined weight of the forced blocks, so the
            # running sum starts from the mass that is already guaranteed attended.
            sorted_values = torch.cat(
                [
                    torch.zeros(
                        (batch_size, head_num, chunk_num, 1), device=input_tensor.device
                    ),
                    torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True),
                    sorted_values[:, :, :, :-2],
                ],
                dim=-1,
            )

            # Sort block indices the same way; the huge additive boost pushes the
            # forced blocks to the front of the ordering.
            _, index = torch.sort(
                torch.where(mask, 100000 * (1 + input_tensor), input_tensor),
                dim=-1,
                descending=True,
            )
            # Cumulative weight BEFORE each candidate: a candidate is kept while
            # the mass accumulated so far is still below the required sum.
            cumulative_sum_without_self = torch.cat(
                [
                    torch.zeros(
                        (batch_size, head_num, chunk_num, 1), device=input_tensor.device
                    ),
                    sorted_values[:, :, :, 0:-1],
                ],
                dim=-1,
            ).cumsum(dim=-1)

            index_mask = cumulative_sum_without_self < required_sum
            # Rejected candidates are redirected to block 0, which is already True,
            # so the scatter below is a no-op for them.
            index = torch.where(index_mask, index, 0)
            mask = mask.view(batch_size, head_num * chunk_num, block_num)
            index = index.view(batch_size, head_num * chunk_num, block_num)
            mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True
            mask = mask.view(batch_size, head_num, chunk_num, block_num)
        else:
            # Non-causal: no forced blocks, plain greedy top-mass selection.
            mask = torch.zeros_like(input_tensor, dtype=torch.bool)
            sorted_values, index = torch.sort(
                input_tensor, dim=-1, descending=True
            )
            sorted_values = sorted_values.to(input_tensor.device)
            cumulative_sum_without_self = torch.cat(
                [
                    torch.zeros(
                        (batch_size, head_num, chunk_num, 1), device=input_tensor.device
                    ),
                    sorted_values[:, :, :, 0:-1],
                ],
                dim=-1,
            ).cumsum(dim=-1)
            index_mask = cumulative_sum_without_self < required_sum
            index = torch.where(index_mask, index, 0)
            mask = mask.view(batch_size, head_num * chunk_num, block_num)
            index = index.view(batch_size, head_num * chunk_num, block_num)
            mask[
                :,
                torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1),
                index,
            ] = True
            mask = mask.view(batch_size, head_num, chunk_num, block_num)
    else:
        raise NotImplementedError("block num chunk prefill not implemented")

    if causal:
        # Blocks past the current chunk window must never be attended. The original
        # wrapped an assert in a bare ``except:`` and did this assignment only in the
        # handler — a bare except also swallows KeyboardInterrupt, and under ``-O``
        # the assert disappears so the cleanup silently never runs. Assigning
        # unconditionally is exactly equivalent (a no-op when the invariant held).
        mask[:, :, :, current_index + chunk_num :] = False

    # Sanity checks on the invariants of the causal selection.
    if causal:
        if decoding:
            # The sink block and the newest block must both be selected.
            assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all()
        else:
            # The forced blocks (sink + diagonal) must still be present.
            lambda_mask = torch.zeros_like(input_tensor, dtype=bool, device=input_tensor.device)
            lambda_mask[:, :, :, 0] = 1
            lambda_mask[:, :, :, current_index : current_index + chunk_num] = torch.eye(
                chunk_num, device=lambda_mask.device
            ).unsqueeze(0).unsqueeze(0).expand(1, head_num, chunk_num, chunk_num)
            assert torch.where(lambda_mask, mask, True).all()

    return mask
|