feat: add XAttention sparse policy integration
Integrate COMPASS XAttention algorithm into nano-vllm's CPU offload execution path. Uses FlashAttention with native GQA support for offload mode. New files: - nanovllm/kvcache/sparse/utils.py: find_blocks_chunked() utility - nanovllm/kvcache/sparse/kernels.py: Triton kernels for XAttention - nanovllm/kvcache/sparse/xattn.py: XAttentionPolicy implementation Modified: - nanovllm/config.py: Add XATTN configuration parameters - nanovllm/engine/model_runner.py: Support XATTN policy - nanovllm/kvcache/sparse/__init__.py: Register XAttentionPolicy - tests/test_ruler.py: Add --sparse-policy parameter Test results (32k ruler): - NIAH tasks: 12/12 (100%) - QA/Recall tasks: 11/15 (73%) - Overall: 23/27 (85%) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
156
nanovllm/kvcache/sparse/utils.py
Normal file
156
nanovllm/kvcache/sparse/utils.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Utility functions for sparse attention policies.
|
||||
|
||||
Copied from COMPASS/compass/src/utils.py for XAttention integration.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def find_blocks_chunked(
|
||||
input_tensor, current_index, threshold, num_to_choose, decoding: bool, mode: str = "both", causal=True
|
||||
):
|
||||
"""
|
||||
Finds and selects relevant blocks of attention for transformer-based models based on a
|
||||
threshold or a predefined number of blocks.
|
||||
|
||||
Parameters:
|
||||
- input_tensor (torch.Tensor): The input tensor of shape (batch_size, head_num, chunk_num, block_num).
|
||||
- current_index (int): The current index in the sequence processing.
|
||||
- threshold (float or None): A threshold value used to determine the minimum attention weight sum.
|
||||
- num_to_choose (int or None): The number of blocks to be selected, ensuring sufficient information retrieval.
|
||||
- decoding (bool): If True, operates in decoding mode; otherwise, it's in encoding mode.
|
||||
- mode (str): Defines the processing mode, either 'both', 'prefill', or 'decode'.
|
||||
- causal (bool): If True, applies causal masking to prevent future information leakage.
|
||||
|
||||
Returns:
|
||||
- torch.Tensor: A boolean mask of shape (batch_size, head_num, chunk_num, block_num),
|
||||
indicating which blocks should be attended to.
|
||||
"""
|
||||
assert threshold is None or num_to_choose is None
|
||||
batch_size, head_num, chunk_num, block_num = input_tensor.shape
|
||||
|
||||
if mode == "prefill" and decoding:
|
||||
return torch.ones_like(input_tensor, dtype=torch.bool)
|
||||
if mode == "decode" and not decoding:
|
||||
mask = torch.ones_like(input_tensor, dtype=torch.bool)
|
||||
if causal:
|
||||
mask[:, :, :, current_index : current_index + chunk_num] = torch.tril(
|
||||
torch.ones(1, head_num, chunk_num, chunk_num, device=input_tensor.device)
|
||||
)
|
||||
mask[:, :, current_index + chunk_num :, :] = 0
|
||||
return torch.cat(
|
||||
[
|
||||
torch.ones_like(input_tensor, dtype=torch.bool)[:, :, 0 : current_index + 1],
|
||||
torch.zeros_like(input_tensor, dtype=torch.bool)[:, :, current_index + 1 :],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
else:
|
||||
return mask
|
||||
|
||||
input_tensor = input_tensor.to(float)
|
||||
|
||||
if threshold is not None:
|
||||
total_sum = input_tensor.sum(dim=-1, keepdim=True)
|
||||
if isinstance(threshold, torch.Tensor):
|
||||
threshold = threshold.to(float)
|
||||
required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze(
|
||||
-1
|
||||
).expand((batch_size, head_num, chunk_num, 1)).to(input_tensor.device)
|
||||
else:
|
||||
required_sum = total_sum * threshold
|
||||
|
||||
if causal:
|
||||
mask = torch.zeros_like(input_tensor, dtype=torch.bool)
|
||||
mask[:, :, :, 0] = 1
|
||||
mask[:, :, :, current_index : current_index + chunk_num] = (
|
||||
torch.eye(chunk_num, device=mask.device)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(0)
|
||||
.expand(1, head_num, chunk_num, chunk_num)
|
||||
)
|
||||
other_values = input_tensor.masked_fill(mask, 0)
|
||||
sorted_values, _ = torch.sort(
|
||||
other_values, dim=-1, descending=True
|
||||
)
|
||||
sorted_values = sorted_values.to(input_tensor.device)
|
||||
|
||||
sorted_values = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
|
||||
),
|
||||
torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True),
|
||||
sorted_values[:, :, :, :-2],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
_, index = torch.sort(
|
||||
torch.where(mask, 100000 * (1 + input_tensor), input_tensor),
|
||||
dim=-1,
|
||||
descending=True
|
||||
)
|
||||
cumulative_sum_without_self = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
|
||||
),
|
||||
sorted_values[:, :, :, 0:-1],
|
||||
],
|
||||
dim=-1,
|
||||
).cumsum(dim=-1)
|
||||
|
||||
index_mask = cumulative_sum_without_self < required_sum
|
||||
index = torch.where(index_mask, index, 0)
|
||||
mask = mask.view(batch_size, head_num * chunk_num, block_num)
|
||||
index = index.view(batch_size, head_num * chunk_num, block_num)
|
||||
mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True
|
||||
mask = mask.view(batch_size, head_num, chunk_num, block_num)
|
||||
else:
|
||||
mask = torch.zeros_like(input_tensor, dtype=torch.bool)
|
||||
sorted_values, index = torch.sort(
|
||||
input_tensor, dim=-1, descending=True
|
||||
)
|
||||
sorted_values = sorted_values.to(input_tensor.device)
|
||||
cumulative_sum_without_self = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
(batch_size, head_num, chunk_num, 1), device=input_tensor.device
|
||||
),
|
||||
sorted_values[:, :, :, 0:-1],
|
||||
],
|
||||
dim=-1,
|
||||
).cumsum(dim=-1)
|
||||
index_mask = cumulative_sum_without_self < required_sum
|
||||
index = torch.where(index_mask, index, 0)
|
||||
mask = mask.view(batch_size, head_num * chunk_num, block_num)
|
||||
index = index.view(batch_size, head_num * chunk_num, block_num)
|
||||
mask[
|
||||
:,
|
||||
torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1),
|
||||
index,
|
||||
] = True
|
||||
mask = mask.view(batch_size, head_num, chunk_num, block_num)
|
||||
else:
|
||||
raise NotImplementedError("block num chunk prefill not implemented")
|
||||
|
||||
try:
|
||||
if causal:
|
||||
assert (~mask[:, :, :, current_index + chunk_num :]).all()
|
||||
except:
|
||||
mask[:, :, :, current_index + chunk_num :] = False
|
||||
|
||||
if causal:
|
||||
if decoding:
|
||||
assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all()
|
||||
else:
|
||||
lambda_mask = torch.zeros_like(input_tensor, dtype=bool, device=input_tensor.device)
|
||||
lambda_mask[:, :, :, 0] = 1
|
||||
lambda_mask[:, :, :, current_index:current_index+chunk_num] = torch.eye(
|
||||
chunk_num, device=lambda_mask.device
|
||||
).unsqueeze(0).unsqueeze(0).expand(1, head_num, chunk_num, chunk_num)
|
||||
assert(torch.where(lambda_mask, mask, True).all())
|
||||
|
||||
return mask
|
||||
Reference in New Issue
Block a user