Files
nano-vllm/docs/block_sparse_attn_interface.md
Zijie Tian 4cbd451af7 📝 docs: add BSA interface documentation and cleanup temp files
- Add docs/block_sparse_attn_interface.md with BSA function signatures
- Update CLAUDE.md documentation index
- Remove obsolete DEBUG_SUMMARY.md and test_report_sparse_policy_refactor.md
- Add notes.md to .gitignore

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-20 04:27:19 +08:00

8.3 KiB

Block Sparse Attention Interface

Source: MIT-HAN-LAB/Block-Sparse-Attention

This document records the BSA (Block Sparse Attention) interface used by XAttention for sparse attention computation.

Installation

BSA is installed in the minference conda environment:

/home/zijie/anaconda3/envs/minference/lib/python3.10/site-packages/block_sparse_attn/

To use it from other environments, prepend the minference site-packages directory to PYTHONPATH:

PYTHONPATH=/home/zijie/anaconda3/envs/minference/lib/python3.10/site-packages:$PYTHONPATH python script.py

Interface Code

# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_blocksparse_attn_interface.py

import block_sparse_attn_cuda
import torch
import torch.nn as nn


def convert_blockmask(blockmask, causal):
    """Convert from the 0-1 format to the format used by the CUDA code.

    0 means the block is skipped; nonzero means the block is not skipped.

    Arguments:
        blockmask: (row, col): a 2-D 0-1 tensor (unpacked as ``nrow, ncol``).
        causal: must be falsy; the causal case is not supported by this converter.
    Return:
        blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row
            indices of the nonzero blocks, padded with -1 to reach length @row.
            The indices are multiplied by 4, with the smallest bit used to encode whether
            it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is
            the last nonzero in its row.
    Raises:
        AssertionError: if ``causal`` is truthy (only when assertions are enabled).
    """
    assert not causal
    nrow, ncol = blockmask.shape
    # Sort does not support bool on CUDA
    blockmask = blockmask.to(dtype=torch.uint8)
    # Per column: a stable descending sort moves nonzero entries to the top while
    # keeping their original (ascending) row order; zeros sink to the bottom.
    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
    # argsort of a permutation is its inverse: for each original (row, col) entry,
    # the position it was moved to inside its sorted column.
    nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
    # Stable ascending sort along each row puts the highest-index nonzero column last.
    # (For an all-zero row this is simply the last column; its flag is discarded below.)
    last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
    # Where that (row, last-nonzero-col) entry landed within its sorted column.
    last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
    ]
    # Stable descending sort along each row puts the lowest-index nonzero column first.
    first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
    first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
    ]
    # Shift row indices left by 2 bits so the two low bits can carry flags.
    nonzero_idx = nonzero_sorted_rowidx * 4
    # Bit 1 (+2): this entry is the last nonzero block in its row.
    nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
    # Bit 0 (+1): this entry is the first nonzero block in its row.
    nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
    # Zero (skipped) entries become -1 padding; this also overwrites any flags
    # that were written for all-zero rows above.
    nonzero_idx[nonzero_val == 0] = -1
    # Transpose to (col, row) layout expected by the kernel.
    return nonzero_idx.T.contiguous().to(dtype=torch.int32)


def convert_blockmask_row_reverse(blockmask, causal=False):
    """List, per row, the last-dim indices of the nonzero blocks, highest index first.

    For a 0-1 mask, each output row holds the indices of the nonzero entries of
    the corresponding ``blockmask`` row in descending order, followed by -1
    padding (one per zero entry). The output has the same shape as ``blockmask``
    and is a contiguous torch.int32 tensor. ``causal`` is accepted for signature
    compatibility and is not used.
    """
    as_bytes = blockmask.to(dtype=torch.uint8)
    sorted_vals, order = as_bytes.sort(dim=-1, stable=True, descending=False)
    # Zero entries sort to the front of each row; turn their indices into padding.
    padded = torch.where(sorted_vals == 0, torch.full_like(order, -1), order)
    # Reverse each row so the nonzero indices lead and the -1 padding trails.
    reversed_rows = padded.flip(-1)
    return reversed_rows.contiguous().to(dtype=torch.int32)


def convert_blockmask_col_reverse(blockmask, causal=False):
    """List, per column, the row indices of the nonzero blocks, highest index first.

    For a 0-1 mask, the result has the last two dimensions of ``blockmask``
    swapped. Each output row corresponds to one input column and holds that
    column's nonzero row indices in descending order, followed by -1 padding.
    Returned as a contiguous torch.int32 tensor. ``causal`` is accepted for
    signature compatibility and is not used.
    """
    as_bytes = blockmask.to(dtype=torch.uint8)
    sorted_vals, order = as_bytes.sort(dim=-2, stable=True, descending=False)
    # Zero entries sort to the top of each column; turn their indices into padding.
    padded = torch.where(sorted_vals == 0, torch.full_like(order, -1), order)
    # Reverse along the row axis so nonzero indices come first, then lay columns out as rows.
    per_column = padded.flip(-2).transpose(-1, -2)
    return per_column.contiguous().to(dtype=torch.int32)


def replace_ones_with_count(tensor):
    """Renumber the entries equal to 1 as 1, 2, 3, ... along the last dimension.

    Entries not equal to 1 are left unchanged, and the numbering restarts for
    each slice along the last dimension. The input tensor is not modified.

    Returns:
        (renumbered, ones_num): a new tensor with the same dtype as ``tensor``,
        and a 0-dim tensor holding the total number of entries equal to 1.
    """
    is_one = tensor == 1
    # Running count of ones seen so far in each last-dim slice, in the input dtype.
    running = torch.cumsum(is_one, dim=-1).to(tensor.dtype)
    renumbered = torch.where(is_one, running, tensor)
    return renumbered, is_one.sum()


def _block_sparse_attn_forward(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    m_block_dim, n_block_dim,
    head_mask_type,
    streaming_info,
    row_blockmask,
    max_seqlen_q_, max_seqlen_k_,
    p_dropout,
    softmax_scale,
    is_causal,
    exact_streaming,
    return_softmax,
    window_size_left,
    window_size_right
):
    """Run the ``block_sparse_attn_cuda.fwd_block`` forward kernel.

    Every argument is forwarded positionally, in signature order. A trailing
    ``None`` is appended for the kernel's final argument. The kernel's result
    must unpack into exactly eight items, returned as the tuple
    (out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state).
    """
    kernel_args = (
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        m_block_dim, n_block_dim,
        head_mask_type,
        streaming_info,
        row_blockmask,
        max_seqlen_q_, max_seqlen_k_,
        p_dropout,
        softmax_scale,
        is_causal,
        exact_streaming,
        return_softmax,
        window_size_left,
        window_size_right,
        # Final kernel argument, always None here (presumably an optional RNG
        # generator as in flash-attention -- confirm against the C++ binding).
        None,
    )
    (
        out, q_ret, k_ret, v_ret,
        out_padded, softmax_lse, S_dmask, rng_state,
    ) = block_sparse_attn_cuda.fwd_block(*kernel_args)
    return out, q_ret, k_ret, v_ret, out_padded, softmax_lse, S_dmask, rng_state


def block_sparse_attn_func(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    head_mask_type,
    streaming_info,
    base_blockmask,
    max_seqlen_q_, max_seqlen_k_,
    p_dropout,
    deterministic=False,
    softmax_scale=None,
    is_causal=False,
    exact_streaming=False,
    return_attn_probs=False,
):
    """
    Main entry point for block sparse attention.

    Args:
        q: Query tensor [total_q, num_heads, head_dim]
        k: Key tensor [total_k, num_heads, head_dim]
        v: Value tensor [total_k, num_heads, head_dim]
        cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
        cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
        head_mask_type: Per-head mask type [num_heads]. Entries equal to 1 mark
            block-sparse heads. Before the kernel call they are renumbered
            1..N in order (N = number of such heads); other values pass through
            unchanged. The caller's tensor is not modified.
        streaming_info: Optional streaming attention info (None when unused)
        base_blockmask: Block mask [batch, N, q_blocks, k_blocks], where N is the
            number of head_mask_type entries equal to 1 (N == num_heads when
            every head is block sparse). May be None.
        max_seqlen_q_: Maximum Q sequence length
        max_seqlen_k_: Maximum K sequence length
        p_dropout: Dropout probability (0.0 for eval)
        deterministic: Whether to use deterministic algorithms
        softmax_scale: Softmax scale (default: 1/sqrt(head_dim))
        is_causal: Whether to apply causal masking
        exact_streaming: Whether to use exact streaming attention
        return_attn_probs: Whether to return attention probabilities

    Returns:
        The result of ``BlockSparseAttnFun.apply`` (the attention output
        [total_q, num_heads, head_dim]), or of ``BlockSparseAttnFunWithS.apply``
        when return_attn_probs=True. NOTE(review): the WithS variant presumably
        returns additional tensors -- confirm against its definition.

    Raises:
        ValueError: if base_blockmask is given and base_blockmask.shape[1] != N.
    """
    head_mask_type, blocksparse_head_num = replace_ones_with_count(head_mask_type)
    if base_blockmask is not None:
        # Explicit check (not `assert`) so it survives `python -O`; a mismatched
        # mask would otherwise be handed straight to the CUDA kernel.
        num_sparse_heads = int(blocksparse_head_num)
        if base_blockmask.shape[1] != num_sparse_heads:
            raise ValueError(
                f"base_blockmask.shape[1] ({base_blockmask.shape[1]}) must equal the "
                f"number of block-sparse heads (head_mask_type entries equal to 1): "
                f"{num_sparse_heads}"
            )

    func = BlockSparseAttnFun if not return_attn_probs else BlockSparseAttnFunWithS
    return func.apply(
                q, k, v,
                cu_seqlens_q, cu_seqlens_k,
                128, 128,  # m_block_dim, n_block_dim (fixed at 128)
                head_mask_type,
                streaming_info,
                base_blockmask,
                max_seqlen_q_, max_seqlen_k_,
                p_dropout,
                softmax_scale,
                is_causal,
                exact_streaming,
                return_attn_probs,
                -1, -1,  # window_size_left, window_size_right
                deterministic
                )

Usage Example (from COMPASS)

from block_sparse_attn import block_sparse_attn_func

# Context assumed from the surrounding COMPASS code (not shown here):
# query_states / key_states / value_states presumably shaped
# [batch_size, num_heads, seq_len, head_dim], plus batch_size, num_heads,
# head_dim, q_len, k_len and device.

# After xattn_estimate returns sparse mask
attn_sums, approx_simple_mask = xattn_estimate(query_states, key_states, ...)

# Reshape for BSA (requires [seq_len, num_heads, head_dim] format).
# .view() after .transpose() only works here because batch_size == 1 (a BSA
# constraint): the size-1 batch dim is dropped without copying.
query_states = query_states.transpose(1, 2).view(q_len, num_heads, head_dim)
key_states = key_states.transpose(1, 2).view(k_len, num_heads, head_dim)
value_states = value_states.transpose(1, 2).view(k_len, num_heads, head_dim)

# Cumulative sequence lengths
q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32, device=device)
k_cu_seq_lens = torch.tensor([0, k_len], dtype=torch.int32, device=device)

# Head mask type (1 for all heads using block sparse)
head_mask_type = torch.tensor([1] * num_heads, device=device, dtype=torch.int32)

# BSA uses fixed 128-token blocks, so the mask must cover ceil(len / 128)
# blocks per side; any extra blocks the estimator may have produced are trimmed.
BSA_BLOCK_SIZE = 128
q_block_num = (q_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
k_block_num = (k_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE

# Call BSA
attn_output = block_sparse_attn_func(
    query_states,
    key_states,
    value_states,
    q_cu_seq_lens,
    k_cu_seq_lens,
    head_mask_type,
    None,  # streaming_info
    approx_simple_mask[:, :, :q_block_num, :k_block_num].contiguous(),
    q_len,
    k_len,
    p_dropout=0.0,
    deterministic=True,
    is_causal=True,
)

# Reshape back to [batch, num_heads, seq_len, head_dim]
attn_output = attn_output.view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)

Key Constraints

  • Block size: Fixed at 128 tokens (block_sparse_attn_func hardcodes m_block_dim = n_block_dim = 128)
  • Batch size: Only batch_size=1 supported for block sparse mode
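  • Mask format: [batch, num_sparse_heads, q_blocks, k_blocks] boolean tensor, where num_sparse_heads is the number of heads whose head_mask_type is 1 (equal to num_heads when every head is block sparse) and q_blocks/k_blocks are block counts at the 128-token block size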
  • Input format: [total_seq_len, num_heads, head_dim] (not batched)