📝 docs: add BSA interface documentation and cleanup temp files
- Add docs/block_sparse_attn_interface.md with BSA function signatures - Update CLAUDE.md documentation index - Remove obsolete DEBUG_SUMMARY.md and test_report_sparse_policy_refactor.md - Add notes.md to .gitignore Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
238
docs/block_sparse_attn_interface.md
Normal file
238
docs/block_sparse_attn_interface.md
Normal file
@@ -0,0 +1,238 @@
# Block Sparse Attention Interface

Source: [MIT-HAN-LAB/Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention)

This document records the BSA (Block Sparse Attention) interface used by XAttention for sparse attention computation.

## Installation

BSA is installed in the `minference` conda environment:

```
/home/zijie/anaconda3/envs/minference/lib/python3.10/site-packages/block_sparse_attn/
```

To use in other environments, add to PYTHONPATH:

```bash
PYTHONPATH=/home/zijie/anaconda3/envs/minference/lib/python3.10/site-packages:$PYTHONPATH python script.py
```

## Interface Code

```python
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_blocksparse_attn_interface.py

import block_sparse_attn_cuda
import torch
import torch.nn as nn

def convert_blockmask(blockmask, causal):
|
||||
"""Convert from the 0-1 format to the format used by the CUDA code.
|
||||
0 means the block is skipped.
|
||||
nonzero means the block is not skipped.
|
||||
Argument:
|
||||
blockmask: (row, col): a 0-1 tensor
|
||||
Return:
|
||||
blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
|
||||
indices of the nonzero blocks, padded with -1 to reach length @row.
|
||||
The indices are multiplied by 4, with the smallest bit used to encode whether
|
||||
it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is
|
||||
the last nonzero in its row..
|
||||
"""
|
||||
assert not causal
|
||||
nrow, ncol = blockmask.shape
|
||||
# Sort does not support bool on CUDA
|
||||
blockmask = blockmask.to(dtype=torch.uint8)
|
||||
nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
|
||||
nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
|
||||
last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
|
||||
last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
|
||||
torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
|
||||
]
|
||||
first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
|
||||
first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
|
||||
torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
|
||||
]
|
||||
nonzero_idx = nonzero_sorted_rowidx * 4
|
||||
nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
|
||||
nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
|
||||
nonzero_idx[nonzero_val == 0] = -1
|
||||
return nonzero_idx.T.contiguous().to(dtype=torch.int32)


def convert_blockmask_row_reverse(blockmask, causal=False):
    """Per row, list the kept (nonzero) column indices in descending order.

    Skipped blocks (zeros) become -1; after the flip the -1 padding sits at
    the tail of each row.

    Note: ``causal`` is accepted for interface parity but is unused here.
    """
    vals, col_idx = blockmask.to(dtype=torch.uint8).sort(
        dim=-1, stable=True, descending=False
    )
    # Mark skipped blocks, then reverse so kept columns lead each row.
    col_idx[vals == 0] = -1
    flipped = torch.flip(col_idx, dims=[-1])
    return flipped.contiguous().to(dtype=torch.int32)


def convert_blockmask_col_reverse(blockmask, causal=False):
    """Per column, list the kept (nonzero) row indices in descending order.

    The result is transposed to (col, row); skipped blocks become -1 and the
    padding sits at the tail of each output row.

    Note: ``causal`` is accepted for interface parity but is unused here.
    """
    vals, row_idx = blockmask.to(dtype=torch.uint8).sort(
        dim=-2, stable=True, descending=False
    )
    row_idx[vals == 0] = -1
    # Reverse along the row axis so kept rows lead, then swap to (col, row).
    flipped = torch.flip(row_idx, dims=[-2])
    result = torch.transpose(flipped, -1, -2)
    return result.contiguous().to(dtype=torch.int32)


def replace_ones_with_count(tensor):
    """Renumber entries equal to 1 with their running count along the last dim.

    E.g. [1, 0, 1, 1, 2] -> [1, 0, 2, 3, 2]; entries not equal to 1 are left
    untouched.

    Returns:
        (renumbered tensor, number of ones as a 0-dim tensor)
    """
    is_one = tensor == 1
    num_ones = is_one.sum()
    # Running count of ones along the last dimension, zeroed where tensor != 1.
    running = torch.cumsum(is_one, dim=-1).to(tensor.dtype) * is_one
    renumbered = tensor.masked_scatter(is_one, running[is_one])
    return renumbered, num_ones


def _block_sparse_attn_forward(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    m_block_dim, n_block_dim,
    head_mask_type,
    streaming_info,
    row_blockmask,
    max_seqlen_q_, max_seqlen_k_,
    p_dropout,
    softmax_scale,
    is_causal,
    exact_streaming,
    return_softmax,
    window_size_left,
    window_size_right
):
    """Forward pass: thin pass-through to the compiled CUDA kernel.

    All arguments are forwarded positionally, in order, to
    ``block_sparse_attn_cuda.fwd_block``; this function adds no logic and
    simply unpacks the kernel's 8-element result tuple.
    """
    # NOTE(review): the trailing None fills one extra positional slot the CUDA
    # binding expects beyond the named arguments here — presumably an optional
    # output/RNG-generator argument; confirm against the fwd_block binding.
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = block_sparse_attn_cuda.fwd_block(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        m_block_dim, n_block_dim,
        head_mask_type,
        streaming_info,
        row_blockmask,
        max_seqlen_q_, max_seqlen_k_,
        p_dropout,
        softmax_scale,
        is_causal,
        exact_streaming,
        return_softmax,
        window_size_left,
        window_size_right,
        None
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def block_sparse_attn_func(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    head_mask_type,
    streaming_info,
    base_blockmask,
    max_seqlen_q_, max_seqlen_k_,
    p_dropout,
    deterministic=False,
    softmax_scale=None,
    is_causal=False,
    exact_streaming=False,
    return_attn_probs=False,
):
    """Main entry point for block sparse attention.

    Args:
        q: Query tensor [total_q, num_heads, head_dim]
        k: Key tensor [total_k, num_heads, head_dim]
        v: Value tensor [total_k, num_heads, head_dim]
        cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
        cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
        head_mask_type: Per-head mask type [num_heads], 1 for block sparse
        streaming_info: Optional streaming attention info
        base_blockmask: Block mask [batch, num_heads, q_blocks, k_blocks]
        max_seqlen_q_: Maximum Q sequence length
        max_seqlen_k_: Maximum K sequence length
        p_dropout: Dropout probability (0.0 for eval)
        deterministic: Whether to use deterministic algorithms
        softmax_scale: Softmax scale (default: 1/sqrt(head_dim))
        is_causal: Whether to apply causal masking
        exact_streaming: Whether to use exact streaming attention
        return_attn_probs: Whether to return attention probabilities

    Returns:
        Attention output [total_q, num_heads, head_dim]
    """
    # Renumber every block-sparse head (mask type 1) with a distinct 1-based
    # index and count how many such heads exist.
    head_mask_type, blocksparse_head_num = replace_ones_with_count(head_mask_type)
    if base_blockmask is not None:
        # The mask's head axis must cover exactly the block-sparse heads.
        assert base_blockmask.shape[1] == blocksparse_head_num

    # Pick the autograd Function variant: the *WithS one also returns the
    # softmax/attention probabilities.
    if return_attn_probs:
        autograd_fn = BlockSparseAttnFunWithS
    else:
        autograd_fn = BlockSparseAttnFun
    return autograd_fn.apply(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        128, 128,  # m_block_dim, n_block_dim (fixed at 128)
        head_mask_type,
        streaming_info,
        base_blockmask,
        max_seqlen_q_, max_seqlen_k_,
        p_dropout,
        softmax_scale,
        is_causal,
        exact_streaming,
        return_attn_probs,
        -1, -1,  # window_size_left, window_size_right
        deterministic
    )
```

## Usage Example (from COMPASS)

```python
from block_sparse_attn import block_sparse_attn_func

# After xattn_estimate returns sparse mask
attn_sums, approx_simple_mask = xattn_estimate(query_states, key_states, ...)

# Reshape for BSA (requires [seq_len, num_heads, head_dim] format)
query_states = query_states.transpose(1, 2).view(q_len, num_heads, head_dim)
key_states = key_states.transpose(1, 2).view(k_len, num_heads, head_dim)
value_states = value_states.transpose(1, 2).view(k_len, num_heads, head_dim)

# Cumulative sequence lengths
q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32, device=device)
k_cu_seq_lens = torch.tensor([0, k_len], dtype=torch.int32, device=device)

# Head mask type (1 for all heads using block sparse)
head_mask_type = torch.tensor([1] * num_heads, device=device, dtype=torch.int32)

# Call BSA
attn_output = block_sparse_attn_func(
    query_states,
    key_states,
    value_states,
    q_cu_seq_lens,
    k_cu_seq_lens,
    head_mask_type,
    None,  # streaming_info
    approx_simple_mask[:, :, :q_block_num, :k_block_num].contiguous(),
    q_len,
    k_len,
    p_dropout=0.0,
    deterministic=True,
    is_causal=True,
)

# Reshape back to [batch, num_heads, seq_len, head_dim]
attn_output = attn_output.view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
```

## Key Constraints

- **Block size**: Fixed at 128 tokens (hardcoded in BSA)
- **Batch size**: Only batch_size=1 supported for block sparse mode
- **Mask format**: `[batch, num_heads, q_blocks, k_blocks]` boolean tensor
- **Input format**: `[total_seq_len, num_heads, head_dim]` (not batched)
Reference in New Issue
Block a user