- Add docs/block_sparse_attn_interface.md with BSA function signatures - Update CLAUDE.md documentation index - Remove obsolete DEBUG_SUMMARY.md and test_report_sparse_policy_refactor.md - Add notes.md to .gitignore Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
8.3 KiB
8.3 KiB
Block Sparse Attention Interface
Source: MIT-HAN-LAB/Block-Sparse-Attention
This document records the BSA (Block Sparse Attention) interface used by XAttention for sparse attention computation.
Installation
BSA is installed in the minference conda environment:
/home/zijie/anaconda3/envs/minference/lib/python3.10/site-packages/block_sparse_attn/
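To use it from another environment, prepend that site-packages directory to PYTHONPATH. Note that this exposes every package installed in the minference environment, and the compiled block_sparse_attn_cuda extension requires a matching Python version (3.10) and a compatible PyTorch build: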
PYTHONPATH=/home/zijie/anaconda3/envs/minference/lib/python3.10/site-packages:$PYTHONPATH python script.py
Interface Code
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_blocksparse_attn_interface.py
import block_sparse_attn_cuda
import torch
import torch.nn as nn
def convert_blockmask(blockmask, causal):
    """Convert from the 0-1 format to the format used by the CUDA code.
    0 means the block is skipped.
    nonzero means the block is not skipped.
    Argument:
        blockmask: (row, col): a 0-1 tensor
        causal: must be falsy; causal masking is not supported by this
            converter (enforced with an assert).
    Return:
        blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
            indices of the nonzero blocks, padded with -1 to reach length @row.
            The indices are multiplied by 4, with the smallest bit used to encode whether
            it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is
            the last nonzero in its row. A block that is both first and last in its row
            carries both bits (+3).
    """
    assert not causal
    nrow, ncol = blockmask.shape
    # Sort does not support bool on CUDA
    blockmask = blockmask.to(dtype=torch.uint8)
    # Per column: nonzero rows first (descending on value), ties kept in
    # ascending row order because the sort is stable.
    # nonzero_sorted_rowidx[i, j] = original row index of the i-th entry of column j.
    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
    # Inverse permutation: nonzero_unsorted_rowidx[r, j] = position of row r
    # inside column j's sorted order.
    nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
    # Column index of the last nonzero in each row (ascending stable sort puts
    # the right-most nonzero at the end).
    last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
    # Where that (row, col) entry lives after the per-column sort.
    last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
    ]
    # Column index of the first nonzero in each row (descending stable sort
    # puts the left-most nonzero at the front).
    first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
    first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
    ]
    # Shift row indices left by two bits to make room for the first/last flags.
    nonzero_idx = nonzero_sorted_rowidx * 4
    nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
    nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
    # Rows with no nonzero block still get (spurious) flags above, but those
    # slots hold a 0 value, so they are overwritten with the -1 padding here.
    nonzero_idx[nonzero_val == 0] = -1
    # Transpose to (col, row) layout expected by the kernel.
    return nonzero_idx.T.contiguous().to(dtype=torch.int32)
def convert_blockmask_row_reverse(blockmask, causal=False):
    """List, per row, the column indices of the nonzero blocks in descending order.

    Args:
        blockmask: 0-1 tensor whose last dimension indexes block columns; any
            leading dimensions are treated as independent rows.
        causal: accepted for signature compatibility; unused.

    Returns:
        int32 tensor of the same shape. Each row holds the indices of its nonzero
        columns from largest to smallest, followed by -1 padding.
    """
    # uint8 because sorting bool tensors is not supported on CUDA.
    mask_u8 = blockmask.to(dtype=torch.uint8)
    # Ascending stable sort: zero blocks first, nonzero blocks last in
    # ascending column order.
    sorted_vals, col_order = torch.sort(mask_u8, dim=-1, descending=False, stable=True)
    col_order = col_order.masked_fill(sorted_vals == 0, -1)
    # Reversing moves the nonzero indices (now descending) to the front.
    return torch.flip(col_order, dims=[-1]).contiguous().to(dtype=torch.int32)
def convert_blockmask_col_reverse(blockmask, causal=False):
    """List, per column, the row indices of the nonzero blocks in descending order.

    Args:
        blockmask: 0-1 tensor whose last two dimensions are (row, col); any
            leading dimensions are treated independently.
        causal: accepted for signature compatibility; unused.

    Returns:
        int32 tensor with the last two dimensions swapped to (col, row). Each
        column's entry lists its nonzero row indices from largest to smallest,
        followed by -1 padding.
    """
    # uint8 because sorting bool tensors is not supported on CUDA.
    mask_u8 = blockmask.to(dtype=torch.uint8)
    # Ascending stable sort down each column: zeros first, then nonzero rows
    # in ascending row order.
    sorted_vals, row_order = torch.sort(mask_u8, dim=-2, descending=False, stable=True)
    row_order = row_order.masked_fill(sorted_vals == 0, -1)
    # Reverse along the row axis (nonzero indices first, descending), then
    # put columns on the second-to-last axis.
    reversed_rows = torch.flip(row_order, dims=[-2])
    return reversed_rows.transpose(-1, -2).contiguous().to(dtype=torch.int32)
def replace_ones_with_count(tensor):
    """Renumber the entries equal to 1 as 1, 2, 3, ... along the last dimension.

    Entries that are not 1 are left untouched. The running count is taken with
    a cumulative sum over the last dimension, so it restarts for every slice
    along that dimension. The input tensor is not modified.

    Returns:
        (renumbered, ones_num): ``renumbered`` has the same dtype as ``tensor``;
        ``ones_num`` is a 0-dim tensor holding the total number of 1 entries.
    """
    is_one = tensor == 1
    running_count = torch.cumsum(is_one, dim=-1).to(tensor.dtype)
    renumbered = torch.where(is_one, running_count, tensor)
    return renumbered, is_one.sum()
def _block_sparse_attn_forward(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    m_block_dim, n_block_dim,
    head_mask_type,
    streaming_info,
    row_blockmask,
    max_seqlen_q_, max_seqlen_k_,
    p_dropout,
    softmax_scale,
    is_causal,
    exact_streaming,
    return_softmax,
    window_size_left,
    window_size_right
):
    """Forward every argument to ``block_sparse_attn_cuda.fwd_block``.

    The arguments are passed positionally in signature order, followed by a
    trailing ``None`` (presumably an optional RNG generator, as in flash-attn's
    forward; confirm against the extension).

    Returns:
        The 8-tuple produced by the kernel:
        (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
        Unpacking raises if the kernel returns a different number of values.
    """
    kernel_args = (
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        m_block_dim, n_block_dim,
        head_mask_type,
        streaming_info,
        row_blockmask,
        max_seqlen_q_, max_seqlen_k_,
        p_dropout,
        softmax_scale,
        is_causal,
        exact_streaming,
        return_softmax,
        window_size_left,
        window_size_right,
        None,
    )
    (
        attn_out, q_used, k_used, v_used,
        attn_out_padded, lse, s_dmask, rng
    ) = block_sparse_attn_cuda.fwd_block(*kernel_args)
    return attn_out, q_used, k_used, v_used, attn_out_padded, lse, s_dmask, rng
def block_sparse_attn_func(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    head_mask_type,
    streaming_info,
    base_blockmask,
    max_seqlen_q_, max_seqlen_k_,
    p_dropout,
    deterministic=False,
    softmax_scale=None,
    is_causal=False,
    exact_streaming=False,
    return_attn_probs=False,
):
    """
    Main entry point for block sparse attention.

    Args:
        q: Query tensor [total_q, num_heads, head_dim]
        k: Key tensor [total_k, num_heads, head_dim]
        v: Value tensor [total_k, num_heads, head_dim]
        cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
        cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
        head_mask_type: Per-head mask type [num_heads]; entries equal to 1 mark
            block-sparse heads. Before being forwarded, those entries are
            renumbered 1..K, where K is the number of block-sparse heads.
        streaming_info: Optional streaming attention info
        base_blockmask: Block mask [batch, K, q_blocks, k_blocks], where K is the
            number of heads whose head_mask_type is 1 (K == num_heads when every
            head is block sparse). May be None.
        max_seqlen_q_: Maximum Q sequence length
        max_seqlen_k_: Maximum K sequence length
        p_dropout: Dropout probability (0.0 for eval)
        deterministic: Whether to use deterministic algorithms
        softmax_scale: Softmax scale; None is passed through (documented upstream
            as defaulting to 1/sqrt(head_dim)).
        is_causal: Whether to apply causal masking
        exact_streaming: Whether to use exact streaming attention
        return_attn_probs: Whether to return attention probabilities; selects
            BlockSparseAttnFunWithS instead of BlockSparseAttnFun.

    Returns:
        Attention output [total_q, num_heads, head_dim] (with
        return_attn_probs=True, whatever BlockSparseAttnFunWithS.apply returns).

    Raises:
        ValueError: if base_blockmask is given and its dimension 1 does not
            equal the number of block-sparse heads K.
    """
    head_mask_type, blocksparse_head_num = replace_ones_with_count(head_mask_type)
    if base_blockmask is not None:
        expected_heads = int(blocksparse_head_num)
        # Explicit check instead of ``assert`` so validation still runs under
        # ``python -O``; otherwise a mismatched mask would pass through unchecked.
        if base_blockmask.shape[1] != expected_heads:
            raise ValueError(
                f"base_blockmask.shape[1] ({base_blockmask.shape[1]}) must equal "
                f"the number of block-sparse heads in head_mask_type ({expected_heads})"
            )
    func = BlockSparseAttnFun if not return_attn_probs else BlockSparseAttnFunWithS
    return func.apply(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        128, 128,  # m_block_dim, n_block_dim (fixed at 128)
        head_mask_type,
        streaming_info,
        base_blockmask,
        max_seqlen_q_, max_seqlen_k_,
        p_dropout,
        softmax_scale,
        is_causal,
        exact_streaming,
        return_attn_probs,
        -1, -1,  # window_size_left, window_size_right
        deterministic
    )
Usage Example (from COMPASS)
import torch
from block_sparse_attn import block_sparse_attn_func

# BSA uses a fixed 128-token block size (block_sparse_attn_func hardcodes 128, 128).
BLOCK_SIZE = 128

# query/key/value_states are [batch, num_heads, seq_len, head_dim] (the layout
# restored at the end of this example). Key/value are assumed to have the same
# number of heads as the query (no GQA); the reshapes below require it.
batch_size, num_heads, q_len, head_dim = query_states.shape
k_len = key_states.shape[2]
device = query_states.device
# Block-sparse mode only supports batch_size == 1 (see Key Constraints).
assert batch_size == 1, "BSA block-sparse mode requires batch_size == 1"

# After xattn_estimate returns sparse mask
attn_sums, approx_simple_mask = xattn_estimate(query_states, key_states, ...)

# Number of 128-token blocks covering each sequence; the estimated mask is
# trimmed to this size (presumably it may include padding blocks — confirm).
q_block_num = (q_len + BLOCK_SIZE - 1) // BLOCK_SIZE
k_block_num = (k_len + BLOCK_SIZE - 1) // BLOCK_SIZE

# Reshape for BSA (requires [seq_len, num_heads, head_dim] format)
query_states = query_states.transpose(1, 2).reshape(q_len, num_heads, head_dim)
key_states = key_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)
value_states = value_states.transpose(1, 2).reshape(k_len, num_heads, head_dim)

# Cumulative sequence lengths (a single sequence)
q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32, device=device)
k_cu_seq_lens = torch.tensor([0, k_len], dtype=torch.int32, device=device)

# Head mask type (1 for all heads using block sparse)
head_mask_type = torch.tensor([1] * num_heads, device=device, dtype=torch.int32)

# Call BSA
attn_output = block_sparse_attn_func(
    query_states,
    key_states,
    value_states,
    q_cu_seq_lens,
    k_cu_seq_lens,
    head_mask_type,
    None,  # streaming_info
    approx_simple_mask[:, :, :q_block_num, :k_block_num].contiguous(),
    q_len,
    k_len,
    p_dropout=0.0,
    deterministic=True,
    is_causal=True,
)

# Reshape back to [batch, num_heads, seq_len, head_dim]
attn_output = attn_output.view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
Key Constraints
- Block size: Fixed at 128 tokens (hardcoded in BSA)
- Batch size: Only batch_size=1 supported for block sparse mode
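- Mask format: [batch, num_heads, q_blocks, k_blocks] boolean tensor (dimension 1 must equal the number of heads whose head_mask_type is 1, i.e. num_heads when every head is block sparse)
- Input format: [total_seq_len, num_heads, head_dim] (packed, not batched)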