# Compare commits

2 commits: `690456dbf9` ... `4cbd451af7`

| Author | SHA1 | Date |
|---|---|---|
|  | `4cbd451af7` |  |
|  | `3aef6fc3a2` |  |
**`.gitignore`** (vendored, 1 line added)

```diff
@@ -238,3 +238,4 @@ progress.md
 task_plan_*.md
 findings_*.md
 progress_*.md
+notes.md
```
```
@@ -15,6 +15,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
```

| [`docs/sparse_policy_implementation_guide.md`](docs/sparse_policy_implementation_guide.md) | How to implement a custom SparsePolicy: required methods, hooks, ring buffer pipeline pattern |
| [`docs/sparse_attention_guide.md`](docs/sparse_attention_guide.md) | Block sparse attention methods (XAttention, FlexPrefill, MInference, AvgPool, Quest), computation flow, algorithms |
| [`docs/xattention_algorithm_guide.md`](docs/xattention_algorithm_guide.md) | XAttention algorithm in detail: stride reshape, Triton kernels, BSA dependency, block selection algorithm |
| [`docs/block_sparse_attn_interface.md`](docs/block_sparse_attn_interface.md) | BSA (Block Sparse Attention) interface documentation: function signatures, usage examples, constraints |
| [`docs/debugging_guide.md`](docs/debugging_guide.md) | PyTorch hooks for debugging, hook positions, tensor comparison, memory profiling |
| [`docs/optimization_guide.md`](docs/optimization_guide.md) | Performance optimizations: sgDMA (15x), Triton merge (4.3x), N-way pipeline (2x) |
| [`docs/known_issues.md`](docs/known_issues.md) | Documented bugs and fixes: partial last block bug, block size 4096 race condition |
**`DEBUG_SUMMARY.md`** (deleted, 103 lines)

# Chunked Prefill Bug Debug Summary

## Problem

`test_needle.py --enable-offload --input-len 8192` fails with garbage output.

The model generates completely wrong tokens instead of the expected "7492".

## Investigation Progress

### 1. Stream Synchronization Fix (Completed)

- Replaced the Triton `store_kvcache` kernel with pure PyTorch operations
- Moved `store_kvcache` to `compute_stream` in chunked prefill mode
- Added sync: `compute_stream.wait_event(offload_done)` after per-layer offload
- Added sync: `default_stream.wait_stream(compute_stream)` before return

### 2. KV Cache Alignment Verification (Completed)

Created alignment tests to compare K/V tensors between the torch reference and nanovllm:

**RoPE Alignment:**
- RoPE implementations match closely (max_diff = 0.002, cosine ~1.0)
- Confirmed RoPE is NOT the cause of the bug

**K/V Cache Alignment (Chunk 0):**
- Cosine similarity: ~1.0 for all layers
- Max diff: 2-7 (grows linearly with position, characteristic of FP16 precision)
- Mean diff: < 0.001
- **Conclusion: K/V cache offload is working correctly**

### 3. Layer Output Divergence Analysis (Completed)

Created a per-chunk layer output comparison:

**Chunk 0 (tokens 0-4096):**
- All layers pass with excellent cosine similarity (0.999+)
- Max diff grows in later layers but stays within an acceptable range

**Chunk 1 (tokens 4096-8192):**
- Layers 0-19: OK (cosine ~1.0)
- Layers 20-27: diverge (cosine 0.83-0.96, max_diff up to 114)
- Divergence correlates with later transformer layers

### 4. Critical Discovery: Single-Chunk Offload Also Fails

**Key finding:** Even with input_len=2048 (a single chunk, so no chunked attention), the model produces garbage output when CPU offload is enabled.

```
# Without offload: PASSES
python tests/test_needle.py --input-len 2048
# Output: "7492" (correct)

# With offload: FAILS
python tests/test_needle.py --enable-offload --input-len 2048
# Output: "The Ble White Th G Lopsiswin..." (garbage)
```

**This proves the bug is NOT in:**
- Chunked attention logic (merge_attention_outputs)
- Multi-chunk KV loading
- Ring buffer pipeline

**The bug IS in:**
- The decode path when CPU offload is enabled
- How prefilled KV is loaded/used during decode

### 5. Decode Path Analysis (In Progress)

The decode path in CPU offload mode:
1. Prefill writes KV to GPU and offloads it to CPU
2. Decode loads prefilled KV from CPU via `_decode_ring_buffer_pipeline`
3. Decode attends to the prefilled KV plus the accumulated decode tokens
4. Partial attention results are merged

**Observations:**
- The `prefilled_blocks` set is empty after decode (it should contain block IDs)
- CPU cache has valid data (reasonable mean/std values)
- Decode buffer contains zeros (decode tokens not being stored correctly?)

## Current Status

### Working
- Stream synchronization fixes
- K/V cache offload to CPU (verified alignment)
- RoPE implementation
- Chunked prefill attention for the first chunk

### Not Working
- Decode with CPU offload (even for single-chunk inputs)
- Multi-chunk attention (divergence in later layers for chunk 1)

## Next Steps

1. Debug why `prefilled_blocks` is empty after decode
2. Check whether the decode path correctly loads KV from CPU
3. Verify the decode buffer is being written correctly
4. Compare decode attention outputs between offload and non-offload modes

## Key Files

- `nanovllm/layers/attention.py` - main attention implementation with chunked paths
- `nanovllm/kvcache/offload_engine.py` - CPU-GPU transfer engine
- `nanovllm/kvcache/hybrid_manager.py` - KV cache management with `prefilled_blocks`
- `nanovllm/engine/model_runner.py` - prefill/decode orchestration

## Hypothesis

The decode path fails because:
1. `prefilled_blocks` is not being tracked correctly, causing `get_prefilled_cpu_blocks()` to return an empty set
2. OR the decode attention is not correctly loading/using the prefilled KV from CPU
3. OR there is a stream synchronization issue specific to the decode path
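The alignment metrics quoted throughout the summary (cosine similarity, max diff, mean diff) can be computed with a small helper. This is an illustrative pure-Python sketch over flat float sequences, not the project's actual test code (which operates on torch tensors):

```python
import math

def alignment_metrics(a, b):
    """Compare two equal-length float sequences the way the K/V alignment
    tests do: cosine similarity, max abs diff, mean abs diff."""
    assert len(a) == len(b) and len(a) > 0
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    cosine = dot / (na * nb) if na > 0 and nb > 0 else 0.0
    diffs = [abs(x - y) for x, y in zip(a, b)]
    return cosine, max(diffs), sum(diffs) / len(diffs)

cos, max_d, mean_d = alignment_metrics([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
# identical inputs: cosine == 1.0, max/mean diff == 0.0
```

Note the summary's pattern: cosine ~1.0 with a large max diff indicates FP16 rounding rather than a logic bug, while cosine dropping to 0.83-0.96 (chunk 1, layers 20-27) indicates genuine divergence.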
**`docs/block_sparse_attn_interface.md`** (new file, 238 lines)

# Block Sparse Attention Interface

Source: [MIT-HAN-LAB/Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention)

This document records the BSA (Block Sparse Attention) interface used by XAttention for sparse attention computation.

## Installation

BSA is installed in the `minference` conda environment:
```
/home/zijie/anaconda3/envs/minference/lib/python3.10/site-packages/block_sparse_attn/
```

To use it in other environments, add it to PYTHONPATH:
```bash
PYTHONPATH=/home/zijie/anaconda3/envs/minference/lib/python3.10/site-packages:$PYTHONPATH python script.py
```

## Interface Code

```python
# Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_blocksparse_attn_interface.py

import block_sparse_attn_cuda
import torch
import torch.nn as nn


def convert_blockmask(blockmask, causal):
    """Convert from the 0-1 format to the format used by the CUDA code.

    0 means the block is skipped; nonzero means the block is not skipped.

    Argument:
        blockmask: (row, col): a 0-1 tensor
    Return:
        blockmask_converted: (col, row), dtype torch.int32: for each column, the row
            indices of the nonzero blocks, padded with -1 to reach length @row.
            The indices are multiplied by 4, with the smallest bit encoding whether
            it is the first nonzero in its row, and the 2nd smallest bit encoding
            whether it is the last nonzero in its row.
    """
    assert not causal
    nrow, ncol = blockmask.shape
    # Sort does not support bool on CUDA
    blockmask = blockmask.to(dtype=torch.uint8)
    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True)
    nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0)
    last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1]
    last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row
    ]
    first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0]
    first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[
        torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row
    ]
    nonzero_idx = nonzero_sorted_rowidx * 4
    nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2
    nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1
    nonzero_idx[nonzero_val == 0] = -1
    return nonzero_idx.T.contiguous().to(dtype=torch.int32)


def convert_blockmask_row_reverse(blockmask, causal=False):
    blockmask = blockmask.to(dtype=torch.uint8)
    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=-1, stable=True, descending=False)

    nonzero_idx = nonzero_sorted_rowidx
    nonzero_idx[nonzero_val == 0] = -1
    nonzero_idx = torch.flip(nonzero_idx, dims=[-1])

    return nonzero_idx.contiguous().to(dtype=torch.int32)


def convert_blockmask_col_reverse(blockmask, causal=False):
    blockmask = blockmask.to(dtype=torch.uint8)
    nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=-2, stable=True, descending=False)

    nonzero_idx = nonzero_sorted_rowidx
    nonzero_idx[nonzero_val == 0] = -1
    nonzero_idx = torch.flip(nonzero_idx, dims=[-2])
    nonzero_idx = torch.transpose(nonzero_idx, -1, -2)

    return nonzero_idx.contiguous().to(dtype=torch.int32)


def replace_ones_with_count(tensor):
    # Replace each 1 in the head mask with its running count (1, 2, 3, ...)
    ones_mask = tensor == 1
    ones_num = ones_mask.sum()
    count = torch.cumsum(ones_mask, dim=-1).to(tensor.dtype)
    count = count * ones_mask
    tensor = tensor.masked_scatter(ones_mask, count[ones_mask])
    return tensor, ones_num


def _block_sparse_attn_forward(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    m_block_dim, n_block_dim,
    head_mask_type,
    streaming_info,
    row_blockmask,
    max_seqlen_q_, max_seqlen_k_,
    p_dropout,
    softmax_scale,
    is_causal,
    exact_streaming,
    return_softmax,
    window_size_left,
    window_size_right,
):
    out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = block_sparse_attn_cuda.fwd_block(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        m_block_dim, n_block_dim,
        head_mask_type,
        streaming_info,
        row_blockmask,
        max_seqlen_q_, max_seqlen_k_,
        p_dropout,
        softmax_scale,
        is_causal,
        exact_streaming,
        return_softmax,
        window_size_left,
        window_size_right,
        None,
    )
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def block_sparse_attn_func(
    q, k, v,
    cu_seqlens_q, cu_seqlens_k,
    head_mask_type,
    streaming_info,
    base_blockmask,
    max_seqlen_q_, max_seqlen_k_,
    p_dropout,
    deterministic=False,
    softmax_scale=None,
    is_causal=False,
    exact_streaming=False,
    return_attn_probs=False,
):
    """
    Main entry point for block sparse attention.

    Args:
        q: Query tensor [total_q, num_heads, head_dim]
        k: Key tensor [total_k, num_heads, head_dim]
        v: Value tensor [total_k, num_heads, head_dim]
        cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
        cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
        head_mask_type: Per-head mask type [num_heads], 1 for block sparse
        streaming_info: Optional streaming attention info
        base_blockmask: Block mask [batch, num_heads, q_blocks, k_blocks]
        max_seqlen_q_: Maximum Q sequence length
        max_seqlen_k_: Maximum K sequence length
        p_dropout: Dropout probability (0.0 for eval)
        deterministic: Whether to use deterministic algorithms
        softmax_scale: Softmax scale (default: 1/sqrt(head_dim))
        is_causal: Whether to apply causal masking
        exact_streaming: Whether to use exact streaming attention
        return_attn_probs: Whether to return attention probabilities

    Returns:
        Attention output [total_q, num_heads, head_dim]
    """
    head_mask_type, blocksparse_head_num = replace_ones_with_count(head_mask_type)
    if base_blockmask is not None:
        assert base_blockmask.shape[1] == blocksparse_head_num

    # BlockSparseAttnFun / BlockSparseAttnFunWithS are autograd Functions
    # defined elsewhere in the block_sparse_attn package.
    func = BlockSparseAttnFun if not return_attn_probs else BlockSparseAttnFunWithS
    return func.apply(
        q, k, v,
        cu_seqlens_q, cu_seqlens_k,
        128, 128,  # m_block_dim, n_block_dim (fixed at 128)
        head_mask_type,
        streaming_info,
        base_blockmask,
        max_seqlen_q_, max_seqlen_k_,
        p_dropout,
        softmax_scale,
        is_causal,
        exact_streaming,
        return_attn_probs,
        -1, -1,  # window_size_left, window_size_right
        deterministic,
    )
```

## Usage Example (from COMPASS)

```python
from block_sparse_attn import block_sparse_attn_func

# After xattn_estimate returns the sparse mask
attn_sums, approx_simple_mask = xattn_estimate(query_states, key_states, ...)

# Reshape for BSA (requires [seq_len, num_heads, head_dim] format)
query_states = query_states.transpose(1, 2).view(q_len, num_heads, head_dim)
key_states = key_states.transpose(1, 2).view(k_len, num_heads, head_dim)
value_states = value_states.transpose(1, 2).view(k_len, num_heads, head_dim)

# Cumulative sequence lengths
q_cu_seq_lens = torch.tensor([0, q_len], dtype=torch.int32, device=device)
k_cu_seq_lens = torch.tensor([0, k_len], dtype=torch.int32, device=device)

# Head mask type (1 for all heads using block sparse)
head_mask_type = torch.tensor([1] * num_heads, device=device, dtype=torch.int32)

# Call BSA
attn_output = block_sparse_attn_func(
    query_states,
    key_states,
    value_states,
    q_cu_seq_lens,
    k_cu_seq_lens,
    head_mask_type,
    None,  # streaming_info
    approx_simple_mask[:, :, :q_block_num, :k_block_num].contiguous(),
    q_len,
    k_len,
    p_dropout=0.0,
    deterministic=True,
    is_causal=True,
)

# Reshape back to [batch, num_heads, seq_len, head_dim]
attn_output = attn_output.view(batch_size, q_len, num_heads, head_dim).transpose(1, 2)
```

## Key Constraints

- **Block size**: fixed at 128 tokens (hardcoded in BSA)
- **Batch size**: only batch_size=1 is supported for block sparse mode
- **Mask format**: `[batch, num_heads, q_blocks, k_blocks]` boolean tensor
- **Input format**: `[total_seq_len, num_heads, head_dim]` (not batched)
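Given these constraints, the shape of a valid causal block mask can be sketched as follows. This is a minimal pure-Python illustration (nested lists rather than a torch tensor) of the `[batch, num_heads, q_blocks, k_blocks]` layout with 128-token blocks and batch_size=1; the alignment of the last query block with the last key block is an assumption for the q_len < k_len case:

```python
import math

def causal_block_mask(q_len, k_len, num_heads, block_size=128):
    """Build a [batch=1, num_heads, q_blocks, k_blocks] boolean mask where
    query block i may attend to key block j iff j <= i + offset
    (block-level causality)."""
    q_blocks = math.ceil(q_len / block_size)
    k_blocks = math.ceil(k_len / block_size)
    # Assume queries sit at the end of the key sequence, so the last
    # query block lines up with the last key block.
    offset = k_blocks - q_blocks
    mask = [[[j <= i + offset for j in range(k_blocks)]
             for i in range(q_blocks)]
            for _ in range(num_heads)]
    return [mask]  # batch dimension of 1

m = causal_block_mask(q_len=256, k_len=256, num_heads=2)
# shape [1, 2, 2, 2]: block (0, 1) is masked out (future), diagonal kept
```

In practice `xattn_estimate` produces this mask (as `approx_simple_mask`) by thresholding block-level attention sums rather than from causality alone; causality just bounds which blocks can ever be selected.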
```diff
@@ -11,9 +11,26 @@ from nanovllm.ops.chunked_attention import (
     ChunkedPrefillState,
 )
 
+from nanovllm.ops.xattn import (
+    xattn_estimate,
+    flat_group_gemm_fuse_reshape,
+    softmax_fuse_block_sum,
+    find_blocks_chunked,
+    create_causal_mask,
+    compute_sparsity,
+)
+
 __all__ = [
     # chunked_attention
     "flash_attn_with_lse",
     "merge_attention_outputs",
     "chunked_attention_varlen",
     "ChunkedPrefillState",
+    # xattn
+    "xattn_estimate",
+    "flat_group_gemm_fuse_reshape",
+    "softmax_fuse_block_sum",
+    "find_blocks_chunked",
+    "create_causal_mask",
+    "compute_sparsity",
 ]
```
952
nanovllm/ops/xattn.py
Normal file
952
nanovllm/ops/xattn.py
Normal file
@@ -0,0 +1,952 @@
|
||||
"""
|
||||
XAttention block importance estimation with Triton kernels.
|
||||
|
||||
Ported from COMPASS project (compass/src/Xattention.py, kernels.py, utils.py).
|
||||
|
||||
This module implements the ESTIMATE phase of XAttention, which identifies
|
||||
important blocks using stride-interleaved Q/K reshaping and Triton kernels.
|
||||
|
||||
Architecture:
|
||||
XAttention = Estimate (Triton) + Compute (BSA)
|
||||
This module: Estimate only
|
||||
BSA library: block_sparse_attn (external dependency for compute)
|
||||
|
||||
Key functions:
|
||||
- xattn_estimate: Estimate block importance and generate sparse mask
|
||||
- flat_group_gemm_fuse_reshape: Fused stride reshape + GEMM kernel
|
||||
- softmax_fuse_block_sum: Online softmax + block-wise sum kernel
|
||||
- find_blocks_chunked: Block selection based on cumulative threshold
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from typing import Tuple, Optional
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Triton Kernels
|
||||
# ============================================================
|
||||
|
||||
@triton.jit
|
||||
def softmax_fuse_block_sum_kernel_causal(
|
||||
In,
|
||||
Out,
|
||||
scale,
|
||||
input_stride_0,
|
||||
input_stride_1,
|
||||
input_stride_2,
|
||||
output_stride_0,
|
||||
output_stride_1,
|
||||
output_stride_2,
|
||||
real_q_len,
|
||||
k_len, # we assume k_len is divisible by segment_size
|
||||
chunk_start,
|
||||
chunk_end,
|
||||
segment_size: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Fused softmax + block sum kernel with causal masking.
|
||||
|
||||
This kernel performs online softmax on attention weights and sums
|
||||
within each block, producing block-level attention scores.
|
||||
|
||||
Algorithm:
|
||||
1. Two-pass online softmax (compute max, then normalize)
|
||||
2. Apply causal mask (future positions get -inf)
|
||||
3. Reshape to blocks and sum within each block
|
||||
|
||||
Args (via grid):
|
||||
block_id: Current Q block index
|
||||
head_id: Attention head index
|
||||
batch_id: Batch index
|
||||
|
||||
Input shape: [batch, heads, q_len, k_len]
|
||||
Output shape: [batch, heads, q_blocks, k_blocks]
|
||||
"""
|
||||
block_id = tl.program_id(0)
|
||||
head_id = tl.program_id(1)
|
||||
batch_id = tl.program_id(2)
|
||||
|
||||
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
|
||||
offs_k = tl.arange(0, segment_size)
|
||||
|
||||
num_iters = k_len // segment_size
|
||||
num_iters_before_causal = (chunk_start + (block_id + 1) * block_size - 1) // segment_size
|
||||
|
||||
# Online softmax state
|
||||
m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf") # running max
|
||||
l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0 # running sum
|
||||
|
||||
# Input pointer setup
|
||||
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
|
||||
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
|
||||
|
||||
# Output pointer setup
|
||||
output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
|
||||
output_ptr = output_ptr + tl.arange(0, segment_size // block_size)
|
||||
|
||||
# Pass 1: Compute global max and sum (before causal boundary)
|
||||
for iter in range(0, num_iters_before_causal):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
m_local = tl.max(X, 1)
|
||||
m_new = tl.maximum(m_i, m_local)
|
||||
alpha = tl.math.exp2(m_i - m_new)
|
||||
|
||||
X = X - m_new[:, None]
|
||||
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||
l_i = l_i * alpha + l_local
|
||||
|
||||
m_i = m_new
|
||||
|
||||
# Pass 1 continued: Handle causal boundary
|
||||
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size)
|
||||
X = tl.where(mask, X, -1.0e6)
|
||||
m_local = tl.max(X, 1)
|
||||
m_new = tl.maximum(m_i, m_local)
|
||||
alpha = tl.math.exp2(m_i - m_new)
|
||||
|
||||
X = X - m_new[:, None]
|
||||
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||
l_i = l_i * alpha + l_local
|
||||
|
||||
m_i = m_new
|
||||
|
||||
l_i_inv = 1.0 / l_i
|
||||
|
||||
sum_mask = offs_q[:, None] < real_q_len
|
||||
|
||||
# Pass 2: Normalize and compute block sums (before causal boundary)
|
||||
for iter in range(0, num_iters_before_causal):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
||||
X = tl.where(sum_mask, X, 0)
|
||||
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||
X = tl.sum(X, 2)
|
||||
X = tl.sum(X, 0)
|
||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||
|
||||
# Pass 2 continued: Handle causal boundary
|
||||
for iter in range(num_iters_before_causal, num_iters_before_causal + 1):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
mask = offs_q[:, None] >= (offs_k[None, :] + iter * segment_size)
|
||||
X = tl.where(mask, X, -1.0e6)
|
||||
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
||||
X = tl.where(sum_mask, X, 0)
|
||||
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||
X = tl.sum(X, 2)
|
||||
X = tl.sum(X, 0)
|
||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||
|
||||
# Pass 2 continued: Zero out future blocks
|
||||
for iter in range(num_iters_before_causal + 1, num_iters):
|
||||
X = tl.zeros([segment_size // block_size], dtype=tl.float32)
|
||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def softmax_fuse_block_sum_kernel_non_causal(
|
||||
In,
|
||||
Out,
|
||||
scale,
|
||||
input_stride_0,
|
||||
input_stride_1,
|
||||
input_stride_2,
|
||||
output_stride_0,
|
||||
output_stride_1,
|
||||
output_stride_2,
|
||||
real_q_len,
|
||||
k_len, # we assume k_len is divisible by segment_size
|
||||
chunk_start,
|
||||
chunk_end,
|
||||
segment_size: tl.constexpr,
|
||||
block_size: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Fused softmax + block sum kernel without causal masking.
|
||||
|
||||
Same as causal version but without causal mask application.
|
||||
"""
|
||||
block_id = tl.program_id(0)
|
||||
head_id = tl.program_id(1)
|
||||
batch_id = tl.program_id(2)
|
||||
|
||||
offs_q = tl.arange(0, block_size) + chunk_start + block_id * block_size
|
||||
offs_k = tl.arange(0, segment_size)
|
||||
|
||||
num_iters = k_len // segment_size
|
||||
|
||||
m_i = tl.zeros([block_size], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([block_size], dtype=tl.float32) + 1.0
|
||||
|
||||
input_ptr = In + batch_id * input_stride_0 + head_id * input_stride_1 + block_id * block_size * input_stride_2
|
||||
input_ptr = input_ptr + tl.arange(0, segment_size) + tl.arange(0, block_size)[:, None] * input_stride_2
|
||||
|
||||
output_ptr = Out + batch_id * output_stride_0 + head_id * output_stride_1 + block_id * output_stride_2
|
||||
output_ptr = output_ptr + tl.arange(0, segment_size // block_size)
|
||||
|
||||
# Pass 1: Compute global max and sum
|
||||
for iter in range(0, num_iters):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
m_local = tl.max(X, 1)
|
||||
m_new = tl.maximum(m_i, m_local)
|
||||
alpha = tl.math.exp2(m_i - m_new)
|
||||
|
||||
X = X - m_new[:, None]
|
||||
l_local = tl.sum(tl.math.exp2(X), 1)
|
||||
l_i = l_i * alpha + l_local
|
||||
|
||||
m_i = m_new
|
||||
|
||||
l_i_inv = 1.0 / l_i
|
||||
|
||||
sum_mask = offs_q[:, None] < real_q_len
|
||||
|
||||
# Pass 2: Normalize and compute block sums
|
||||
for iter in range(0, num_iters):
|
||||
X = tl.load(input_ptr + iter * segment_size).to(tl.float32) * scale
|
||||
X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
|
||||
X = tl.where(sum_mask, X, 0)
|
||||
X = tl.reshape(X, (block_size, segment_size // block_size, block_size))
|
||||
X = tl.sum(X, 2)
|
||||
X = tl.sum(X, 0)
|
||||
tl.store(output_ptr + iter * segment_size // block_size, X.to(Out.type.element_ty))
|
||||
|
||||
|
||||
@triton.jit
|
||||
def flat_group_gemm_fuse_reshape_kernel(
|
||||
Q, K, Out,
|
||||
stride_qz, stride_qh, stride_qn,
|
||||
stride_kz, stride_kh, stride_kn,
|
||||
stride_oz, stride_oh, stride_on,
|
||||
chunk_start, chunk_end,
|
||||
H: tl.constexpr,
|
||||
STRIDE: tl.constexpr,
|
||||
HEAD_DIM: tl.constexpr,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
is_causal: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
Fused stride reshape + GEMM kernel.
|
||||
|
||||
This kernel computes Q_reshaped @ K_reshaped^T without explicitly
|
||||
creating the reshaped tensors, saving memory and bandwidth.
|
||||
|
||||
Stride reshape (inverse mode):
|
||||
- K: concat([K[:,:,k::stride,:] for k in range(stride)])
|
||||
- Q: concat([Q[:,:,(stride-1-q)::stride,:] for q in range(stride)])
|
||||
|
||||
The kernel simulates this by adjusting pointer arithmetic:
|
||||
- Q samples backwards: Q_ptrs starts at (stride-1), steps by -1
|
||||
- K samples forwards: K_ptrs starts at 0, steps by +1
|
||||
- Both accumulate across stride iterations
|
||||
|
||||
Args (via grid):
|
||||
block_m: Q block index (in reshaped space)
|
||||
block_n: K block index (in reshaped space)
|
||||
batch_id * H + head_id: Combined batch and head index
|
||||
|
||||
Input shapes:
|
||||
Q: [batch, heads, q_len, head_dim]
|
||||
K: [batch, heads, k_len, head_dim]
|
||||
|
||||
Output shape: [batch, heads, q_len/stride, k_len/stride]
|
||||
"""
|
||||
block_m = tl.program_id(0).to(tl.int64)
|
||||
block_n = tl.program_id(1).to(tl.int64)
|
||||
batch_id = tl.program_id(2).to(tl.int64) // H
|
||||
head_id = tl.program_id(2).to(tl.int64) % H
|
||||
|
||||
# Early exit for causal: skip blocks where K is entirely in the future
|
||||
if is_causal:
|
||||
if chunk_start + (block_m + 1) * BLOCK_M <= block_n * BLOCK_N:
|
||||
return
|
||||
|
||||
# Q pointer: sample from (stride-1) position, step backwards
|
||||
Q_ptrs = Q + batch_id * stride_qz + head_id * stride_qh + block_m * BLOCK_M * STRIDE * stride_qn
|
||||
Q_ptrs = Q_ptrs + tl.arange(0, BLOCK_M)[:, None] * (stride_qn * STRIDE) + tl.arange(0, HEAD_DIM)[None, :] + stride_qn * (STRIDE - 1)
|
||||
|
||||
# K pointer: sample from 0 position, step forwards
|
||||
K_ptrs = K + batch_id * stride_kz + head_id * stride_kh + block_n * BLOCK_N * STRIDE * stride_kn
|
||||
K_ptrs = K_ptrs + tl.arange(0, BLOCK_N)[None, :] * (stride_kn * STRIDE) + tl.arange(0, HEAD_DIM)[:, None]
|
||||
|
||||
o = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
||||
|
||||
# Accumulate Q @ K^T across stride positions
|
||||
for iter in range(STRIDE):
|
||||
q = tl.load(Q_ptrs - iter * stride_qn) # Q steps backwards
|
||||
k = tl.load(K_ptrs + iter * stride_kn) # K steps forwards
|
||||
o += tl.dot(q, k)
|
||||
|
||||
# Store output
|
||||
O_ptrs = Out + batch_id * stride_oz + head_id * stride_oh + block_m * BLOCK_M * stride_on + block_n * BLOCK_N
|
||||
O_ptrs = O_ptrs + tl.arange(0, BLOCK_M)[:, None] * stride_on + tl.arange(0, BLOCK_N)[None, :]
|
||||
|
||||
tl.store(O_ptrs, o.to(Out.type.element_ty))
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Triton Kernel Wrappers
|
||||
# ============================================================
|
||||
|
||||
def softmax_fuse_block_sum(
|
||||
attn_weights_slice: torch.Tensor,
|
||||
reshaped_block_size: int,
|
||||
segment_size: int,
|
||||
chunk_start: int,
|
||||
chunk_end: int,
|
||||
real_q_len: int,
|
||||
scale: float,
|
||||
is_causal: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute softmax and block-wise sum of attention weights.
|
||||
|
||||
This function takes raw QK^T scores (after stride reshape),
|
||||
applies softmax, and sums within each block to produce
|
||||
block-level attention scores.
|
||||
|
||||
Args:
|
||||
attn_weights_slice: Raw attention scores [batch, heads, q_len, k_len]
|
||||
reshaped_block_size: Block size in reshaped space (block_size / stride)
|
||||
segment_size: Processing segment size
|
||||
chunk_start: Start position for this chunk
|
||||
chunk_end: End position for this chunk
|
||||
real_q_len: Actual Q length (before padding)
|
||||
scale: Softmax scale factor (includes 1/sqrt(d) and stride normalization)
|
||||
is_causal: Whether to apply causal masking
|
||||
|
||||
Returns:
|
||||
Block-level attention sums [batch, heads, q_blocks, k_blocks]
|
||||
"""
|
||||
batch_size, num_heads, q_len, k_len = attn_weights_slice.shape
|
||||
|
||||
assert q_len % reshaped_block_size == 0, f"q_len {q_len} must be divisible by reshaped_block_size {reshaped_block_size}"
|
||||
assert k_len % segment_size == 0, f"k_len {k_len} must be divisible by segment_size {segment_size}"
|
||||
assert segment_size % reshaped_block_size == 0, f"segment_size {segment_size} must be divisible by reshaped_block_size {reshaped_block_size}"
|
||||
assert attn_weights_slice.stride(-1) == 1, "Last dimension must be contiguous"
|
||||
|
||||
output = torch.empty(
|
||||
(batch_size, num_heads, q_len // reshaped_block_size, k_len // reshaped_block_size),
|
||||
dtype=attn_weights_slice.dtype,
|
||||
device=attn_weights_slice.device
|
||||
)
|
||||
|
||||
grid = (q_len // reshaped_block_size, num_heads, batch_size)
|
||||
|
||||
if is_causal:
|
||||
softmax_fuse_block_sum_kernel_causal[grid](
|
||||
attn_weights_slice,
|
||||
output,
|
||||
scale,
|
||||
attn_weights_slice.stride(0),
|
||||
attn_weights_slice.stride(1),
|
||||
attn_weights_slice.stride(2),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
output.stride(2),
|
||||
real_q_len,
|
||||
k_len,
|
||||
chunk_start,
|
||||
chunk_end,
|
||||
segment_size,
|
||||
reshaped_block_size,
|
||||
)
|
||||
else:
|
||||
softmax_fuse_block_sum_kernel_non_causal[grid](
|
||||
attn_weights_slice,
|
||||
output,
|
||||
scale,
|
||||
attn_weights_slice.stride(0),
|
||||
attn_weights_slice.stride(1),
|
||||
attn_weights_slice.stride(2),
|
||||
output.stride(0),
|
||||
output.stride(1),
|
||||
output.stride(2),
|
||||
real_q_len,
|
||||
k_len,
|
||||
chunk_start,
|
||||
chunk_end,
|
||||
segment_size,
|
||||
reshaped_block_size,
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def flat_group_gemm_fuse_reshape(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    stride: int,
    chunk_start: int,
    chunk_end: int,
    is_causal: bool = True,
) -> torch.Tensor:
    """
    Compute fused stride reshape + GEMM for Q @ K^T.

    This is the core estimation kernel of XAttention. It computes
    attention scores between strided Q and K without explicitly
    creating the reshaped tensors.

    The stride reshape (inverse mode) works as:
    - K_reshaped: concat([K[:,:,k::stride,:] for k in range(stride)])
    - Q_reshaped: concat([Q[:,:,(stride-1-q)::stride,:] for q in range(stride)])

    Result: Q_reshaped @ K_reshaped^T with shape [batch, heads, q_len/stride, k_len/stride]

    Args:
        query_states: Q tensor [batch, heads, q_len, head_dim]
        key_states: K tensor [batch, heads, k_len, head_dim]
        stride: Stride for reshape (typically 8)
        chunk_start: Start position (in reshaped space) for causal masking
        chunk_end: End position (in reshaped space) for causal masking
        is_causal: Whether to apply causal masking (skip future blocks)

    Returns:
        Attention scores [batch, heads, q_len/stride, k_len/stride]
    """
    batch_size, num_heads, q_len, head_dim = query_states.shape
    kv_len = key_states.shape[2]

    assert key_states.shape[0] == batch_size
    assert key_states.shape[1] == num_heads
    assert key_states.shape[3] == head_dim

    output = torch.empty(
        (batch_size, num_heads, q_len // stride, kv_len // stride),
        dtype=query_states.dtype,
        device=query_states.device,
    )

    # Adjust block size based on GPU shared memory:
    # RTX 3090 has ~100KB, A100/H100 have ~160KB+.
    props = torch.cuda.get_device_properties(torch.cuda.current_device())
    if props.total_memory < 30 * 1024**3:  # Less than 30GB (e.g., RTX 3090 24GB)
        BLOCK_M = 64
        BLOCK_N = 64
    else:
        BLOCK_M = 128
        BLOCK_N = 128

    assert q_len % (stride * BLOCK_M) == 0, f"q_len {q_len} must be divisible by stride*BLOCK_M {stride * BLOCK_M}"
    assert kv_len % (stride * BLOCK_N) == 0, f"kv_len {kv_len} must be divisible by stride*BLOCK_N {stride * BLOCK_N}"

    grid = (q_len // stride // BLOCK_M, kv_len // stride // BLOCK_N, batch_size * num_heads)

    flat_group_gemm_fuse_reshape_kernel[grid](
        query_states,
        key_states,
        output,
        query_states.stride(0),
        query_states.stride(1),
        query_states.stride(2),
        key_states.stride(0),
        key_states.stride(1),
        key_states.stride(2),
        output.stride(0),
        output.stride(1),
        output.stride(2),
        chunk_start,
        chunk_end,
        num_heads,
        stride,
        head_dim,
        BLOCK_M,
        BLOCK_N,
        is_causal,
    )

    return output
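The docstring's reshape can be sanity-checked against an explicit (unfused) CPU implementation; `stride_reshape_scores` below is an illustrative helper with assumed toy shapes, not part of the library:

```python
import torch

def stride_reshape_scores(q, k, stride):
    # Explicit version of the fused kernel's math: concatenate every
    # stride-th row along head_dim (K forward, Q in reverse offset order),
    # then take Q_reshaped @ K_reshaped^T.
    k_r = torch.cat([k[:, :, s::stride, :] for s in range(stride)], dim=-1)
    q_r = torch.cat([q[:, :, (stride - 1 - s)::stride, :] for s in range(stride)], dim=-1)
    return q_r @ k_r.transpose(2, 3)

q = torch.randn(1, 2, 64, 16)  # [batch, heads, q_len, head_dim]
k = torch.randn(1, 2, 64, 16)
scores = stride_reshape_scores(q, k, stride=8)
print(scores.shape)  # torch.Size([1, 2, 8, 8]) = [batch, heads, q_len/stride, k_len/stride]
```

The fused Triton kernel produces the same scores (up to causal masking) without ever materializing `q_r`/`k_r`, avoiding the `stride ×` blow-up of head_dim in memory traffic.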


# ============================================================
# Block Selection Utilities
# ============================================================

def find_blocks_chunked(
    input_tensor: torch.Tensor,
    current_index: int,
    threshold: float,
    num_to_choose: Optional[int],
    decoding: bool,
    mode: str = "both",
    causal: bool = True,
) -> torch.Tensor:
    """
    Select important blocks based on a cumulative attention threshold.

    This function takes block-level attention scores and selects blocks
    that cumulatively account for a specified fraction of total attention.

    Algorithm:
        1. Compute total attention per query block
        2. Sort blocks by attention score (descending)
        3. Accumulate until reaching threshold * total
        4. Mark accumulated blocks as selected
        5. Always keep diagonal blocks (for causal) and the sink block

    Args:
        input_tensor: Block attention scores [batch, heads, q_blocks, k_blocks]
        current_index: Current chunk's starting block index
        threshold: Cumulative attention threshold (e.g., 0.9 = keep 90% of attention mass)
        num_to_choose: Alternative to threshold - select a fixed number of blocks
        decoding: Whether in decode mode (vs prefill)
        mode: "prefill", "decode", or "both"
        causal: Whether to apply causal masking

    Returns:
        Boolean mask [batch, heads, q_blocks, k_blocks] indicating selected blocks
    """
    assert threshold is None or num_to_choose is None, "Only one of threshold or num_to_choose can be specified"

    batch_size, head_num, chunk_num, block_num = input_tensor.shape

    # Special case: prefill-only mode during decoding - keep all blocks
    if mode == "prefill" and decoding:
        return torch.ones_like(input_tensor, dtype=torch.bool)

    # Special case: decode-only mode during prefill
    if mode == "decode" and not decoding:
        mask = torch.ones_like(input_tensor, dtype=torch.bool)
        if causal:
            mask[:, :, :, current_index : current_index + chunk_num] = torch.tril(
                torch.ones(1, head_num, chunk_num, chunk_num, device=input_tensor.device)
            )
            mask[:, :, current_index + chunk_num :, :] = 0
            return torch.cat(
                [
                    torch.ones_like(input_tensor, dtype=torch.bool)[:, :, 0 : current_index + 1],
                    torch.zeros_like(input_tensor, dtype=torch.bool)[:, :, current_index + 1 :],
                ],
                dim=-1,
            )
        else:
            return mask

    # Convert to float for numerical operations
    input_tensor = input_tensor.to(torch.float32)

    if threshold is not None:
        # Compute the required cumulative sum
        total_sum = input_tensor.sum(dim=-1, keepdim=True)
        if isinstance(threshold, torch.Tensor):
            threshold = threshold.to(torch.float32)
            required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze(-1).expand(
                (batch_size, head_num, chunk_num, 1)
            ).to(input_tensor.device)
        else:
            required_sum = total_sum * threshold

        if causal:
            # Initialize the mask with mandatory blocks
            mask = torch.zeros_like(input_tensor, dtype=torch.bool)
            mask[:, :, :, 0] = True  # Sink block is always selected

            # Diagonal blocks (current chunk's causal positions)
            mask[:, :, :, current_index : current_index + chunk_num] = (
                torch.eye(chunk_num, device=mask.device)
                .unsqueeze(0)
                .unsqueeze(0)
                .expand(1, head_num, chunk_num, chunk_num)
            )

            # Mask out mandatory blocks before sorting
            other_values = input_tensor.masked_fill(mask, 0)
            sorted_values, _ = torch.sort(other_values, dim=-1, descending=True)
            sorted_values = sorted_values.to(input_tensor.device)

            # Prepend the mandatory blocks' contribution
            sorted_values = torch.cat(
                [
                    torch.zeros((batch_size, head_num, chunk_num, 1), device=input_tensor.device),
                    torch.where(mask, input_tensor, 0).sum(dim=-1, keepdim=True),
                    sorted_values[:, :, :, :-2],
                ],
                dim=-1,
            )

            # Get sorted indices (mandatory blocks get artificially high priority)
            _, index = torch.sort(
                torch.where(mask, 100000 * (1 + input_tensor), input_tensor),
                dim=-1,
                descending=True,
            )

            # Cumulative sum excluding the current block itself
            cumulative_sum_without_self = torch.cat(
                [
                    torch.zeros((batch_size, head_num, chunk_num, 1), device=input_tensor.device),
                    sorted_values[:, :, :, 0:-1],
                ],
                dim=-1,
            ).cumsum(dim=-1)

            # Select blocks until the threshold is reached
            index_mask = cumulative_sum_without_self < required_sum
            index = torch.where(index_mask, index, 0)

            # Flatten for the scatter operation
            mask = mask.view(batch_size, head_num * chunk_num, block_num)
            index = index.view(batch_size, head_num * chunk_num, block_num)

            # Mark selected blocks
            mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True
            mask = mask.view(batch_size, head_num, chunk_num, block_num)
        else:
            # Non-causal: simple threshold-based selection
            mask = torch.zeros_like(input_tensor, dtype=torch.bool)
            sorted_values, index = torch.sort(input_tensor, dim=-1, descending=True)
            sorted_values = sorted_values.to(input_tensor.device)

            cumulative_sum_without_self = torch.cat(
                [
                    torch.zeros((batch_size, head_num, chunk_num, 1), device=input_tensor.device),
                    sorted_values[:, :, :, 0:-1],
                ],
                dim=-1,
            ).cumsum(dim=-1)

            index_mask = cumulative_sum_without_self < required_sum
            index = torch.where(index_mask, index, 0)

            mask = mask.view(batch_size, head_num * chunk_num, block_num)
            index = index.view(batch_size, head_num * chunk_num, block_num)
            mask[:, torch.arange(mask.shape[1], device=mask.device).unsqueeze(dim=-1), index] = True
            mask = mask.view(batch_size, head_num, chunk_num, block_num)
    else:
        raise NotImplementedError("Fixed block count selection (num_to_choose) is not implemented")

    # Enforce causality: zero out any selected future blocks
    if causal and mask[:, :, :, current_index + chunk_num :].any():
        mask[:, :, :, current_index + chunk_num :] = False

    # Validation
    if causal:
        if decoding:
            assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all()
        else:
            lambda_mask = torch.zeros_like(input_tensor, dtype=torch.bool, device=input_tensor.device)
            lambda_mask[:, :, :, 0] = True
            lambda_mask[:, :, :, current_index : current_index + chunk_num] = (
                torch.eye(chunk_num, device=lambda_mask.device)
                .unsqueeze(0)
                .unsqueeze(0)
                .expand(1, head_num, chunk_num, chunk_num)
            )
            assert torch.where(lambda_mask, mask, True).all()

    return mask
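The cumulative-threshold rule at the heart of this function can be sketched in one dimension; `select_until_threshold` is a hypothetical simplification that ignores the mandatory sink/diagonal blocks:

```python
import torch

def select_until_threshold(scores, threshold):
    # Sort descending and keep a block if the mass accumulated *before*
    # it is still below threshold * total (the cumulative-sum-without-self
    # trick used above).
    total = scores.sum()
    vals, idx = torch.sort(scores, descending=True)
    keep = (torch.cumsum(vals, 0) - vals) < threshold * total
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask[idx[keep]] = True
    return mask

scores = torch.tensor([0.4, 0.35, 0.15, 0.1])
print(select_until_threshold(scores, 0.7))  # tensor([ True,  True, False, False])
```

The two largest blocks already carry 0.75 of the mass, so with `threshold=0.7` the remaining blocks are dropped.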


def create_causal_mask(
    batch_size: int,
    head_num: int,
    block_size: int,
    block_num: int,
    divide_block_num: int,
) -> torch.Tensor:
    """
    Create a causal attention mask for block-level attention.

    Args:
        batch_size: Batch size
        head_num: Number of attention heads
        block_size: Tokens per block
        block_num: Total number of blocks
        divide_block_num: Block index at which the causality boundary is applied

    Returns:
        Causal mask [batch, heads, block_size, block_size * block_num]
    """
    divide_block_num += 1
    if divide_block_num < 1 or divide_block_num > block_num:
        raise ValueError(
            f"divide_block_num ({divide_block_num}) must be between 1 and block_num ({block_num})."
        )

    total_size = block_size * block_num
    device = "cuda"
    mask = torch.zeros(block_size, total_size, device=device)

    # Mask out future blocks entirely
    if divide_block_num < block_num:
        mask[:, divide_block_num * block_size :] = float("-inf")

    # Apply a triangular mask at the causality boundary block
    if divide_block_num - 1 < block_num:
        start_col = (divide_block_num - 1) * block_size
        end_col = start_col + block_size
        upper_tri_mask = torch.triu(
            torch.full((block_size, block_size), float("-inf"), device=device),
            diagonal=1,
        )
        mask[:, start_col:end_col] = upper_tri_mask

    mask = mask.unsqueeze(0).unsqueeze(0)
    mask = mask.expand(batch_size, head_num, block_size, total_size)
    return mask
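A CPU sketch (with assumed toy sizes) of the mask pattern this function builds: blocks past the boundary are fully masked, the boundary block gets an upper-triangular mask, and earlier blocks stay fully visible.

```python
import torch

block_size, block_num, boundary = 4, 3, 2  # boundary block index (0-based)
mask = torch.zeros(block_size, block_size * block_num)
mask[:, (boundary + 1) * block_size:] = float("-inf")  # future blocks
tri = torch.triu(torch.full((block_size, block_size), float("-inf")), diagonal=1)
mask[:, boundary * block_size:(boundary + 1) * block_size] = tri  # boundary block
print((mask == 0).sum().item())  # 42 visible positions: 2*4*4 full blocks + 10 lower-tri
```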


# ============================================================
# Main Estimation Function
# ============================================================

def xattn_estimate(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    block_size: int = 128,
    stride: int = 8,
    norm: float = 1.0,
    threshold: float = 0.9,
    chunk_size: int = 16384,
    use_triton: bool = True,
    causal: bool = True,
    keep_sink: bool = False,
    keep_recent: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Estimate block importance for XAttention sparse selection.

    This function implements the estimation phase of XAttention:
    1. Stride-interleaved reshape of Q and K (inverse mode)
    2. Compute block-level attention scores via Triton kernels
    3. Select important blocks based on a cumulative threshold

    The result is a boolean mask indicating which K blocks each Q block
    should attend to. This mask can be used with BSA (block_sparse_attn)
    for efficient sparse attention computation.

    Args:
        query_states: Q tensor [batch, heads, q_len, head_dim]
        key_states: K tensor [batch, heads, k_len, head_dim]
        block_size: Block size in tokens (must be 128 for BSA compatibility)
        stride: Stride for Q/K reshape (typically 8)
        norm: Normalization factor for attention scores
        threshold: Cumulative attention threshold (0.0-1.0)
        chunk_size: Processing chunk size for memory efficiency
        use_triton: Whether to use Triton kernels (requires SM 80+)
        causal: Whether to apply causal masking
        keep_sink: Always keep the first block (sink tokens)
        keep_recent: Always keep diagonal blocks (recent context)

    Returns:
        attn_sums: Block-level attention scores [batch, heads, q_blocks, k_blocks]
        simple_masks: Boolean mask for sparse attention [batch, heads, q_blocks, k_blocks]

    Example:
        >>> q = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.bfloat16)
        >>> k = torch.randn(1, 32, 4096, 128, device="cuda", dtype=torch.bfloat16)
        >>> attn_sums, mask = xattn_estimate(q, k, block_size=128, stride=8, threshold=0.9)
        >>> # mask can be used with block_sparse_attn_func for sparse computation
    """
    batch_size, num_kv_head, k_len, head_dim = key_states.shape
    batch_size, num_q_head, q_len, head_dim = query_states.shape
    assert num_q_head == num_kv_head, "GQA not supported in estimation (heads must match)"

    # Compute padding to align with chunk_size
    k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
    q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
    k_chunk_num = (k_len + k_num_to_pad) // chunk_size
    k_block_num = (k_len + k_num_to_pad) // block_size
    q_chunk_num = (q_len + q_num_to_pad) // chunk_size
    q_block_num = (q_len + q_num_to_pad) // block_size
    assert k_chunk_num >= q_chunk_num

    # Pad K and Q if needed
    if k_num_to_pad > 0:
        pad_key_states = F.pad(key_states, (0, 0, 0, k_num_to_pad), value=0).to("cuda")
    else:
        pad_key_states = key_states
    if q_num_to_pad > 0:
        pad_query_states = F.pad(query_states, (0, 0, 0, q_num_to_pad), value=0).to("cuda")
    else:
        pad_query_states = query_states

    # Check GPU capability for Triton
    if use_triton:
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        if props.major < 8:
            use_triton = False
            print(f"Triton kernel requires SM 80+, got SM {props.major}{props.minor}. Falling back to PyTorch.")

    # Compute reshaped dimensions
    reshaped_chunk_size = chunk_size // stride
    reshaped_block_size = block_size // stride
    k_reshaped_num_to_pad = k_num_to_pad // stride
    k_reshaped_seq_len = (k_len + k_num_to_pad) // stride
    num_blocks_per_chunk = reshaped_chunk_size // reshaped_block_size

    # Non-Triton fallback: explicit reshape
    if not use_triton:
        # K reshape: concat([K[:,:,k::stride,:] for k in range(stride)])
        reshaped_key = torch.cat(
            [(pad_key_states[:, :, k::stride, :]) for k in range(stride)], dim=-1
        )
        # Q reshape (inverse): concat([Q[:,:,(stride-1-q)::stride,:] for q in range(stride)])
        reshaped_query = torch.cat(
            [(pad_query_states[:, :, (stride - 1 - q)::stride, :]) for q in range(stride)],
            dim=-1,
        )

    attn_sum_list = []
    simple_mask_list = []

    # Process each Q chunk
    for chunk_idx in range(q_chunk_num):
        if use_triton:
            # Triton path: fused reshape + GEMM
            attn_weights_slice = flat_group_gemm_fuse_reshape(
                pad_query_states[
                    :,
                    :,
                    (chunk_idx * reshaped_chunk_size) * stride : (chunk_idx * reshaped_chunk_size + reshaped_chunk_size) * stride,
                    :,
                ],
                pad_key_states,
                stride,
                (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
                (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
                is_causal=causal,
            )

            # Fused softmax + block sum.
            # Scale factor: log2(e) / sqrt(head_dim) / stride / norm,
            # where log2(e) ≈ 1.4426950408889634.
            attn_sum = softmax_fuse_block_sum(
                attn_weights_slice,
                reshaped_block_size,
                min(4096, reshaped_block_size),
                (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
                (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
                k_reshaped_seq_len - k_reshaped_num_to_pad,
                1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
                is_causal=causal,
            )
        else:
            # PyTorch fallback path
            chunked_query = reshaped_query[
                :,
                :,
                chunk_idx * reshaped_chunk_size : (chunk_idx * reshaped_chunk_size + reshaped_chunk_size),
                :,
            ]

            # Compute attention scores
            attn_weights_slice = torch.matmul(
                chunked_query, reshaped_key.transpose(2, 3)
            ).to("cuda")
            attn_weights_slice = attn_weights_slice / math.sqrt(head_dim) / stride / norm

            # Apply the causal mask
            if causal:
                offset_token_chunk_num = k_chunk_num - q_chunk_num
                causal_mask = torch.zeros(
                    (batch_size, num_q_head, reshaped_chunk_size, reshaped_chunk_size * k_chunk_num),
                    device=key_states.device,
                )
                causal_mask[:, :, :, (-k_reshaped_num_to_pad):] = float("-inf")
                chunk_start = (chunk_idx + offset_token_chunk_num) * reshaped_chunk_size
                chunk_end = chunk_start + reshaped_chunk_size
                causal_mask[:, :, :, chunk_start:chunk_end] = torch.triu(
                    torch.ones(1, num_q_head, reshaped_chunk_size, reshaped_chunk_size, device=key_states.device) * float("-inf"),
                    diagonal=1,
                )
                if chunk_idx == q_chunk_num - 1 and q_num_to_pad // stride != 0:
                    causal_mask[:, :, (-(q_num_to_pad // stride)):, :] = float("-inf")
                causal_mask[:, :, :, chunk_end:] = float("-inf")
                attn_weights_slice = attn_weights_slice + causal_mask

            # Softmax
            attn_weights_slice = F.softmax(attn_weights_slice, dim=-1, dtype=torch.float32).to(pad_query_states.dtype)

            if chunk_idx == q_chunk_num - 1 and q_num_to_pad // stride != 0:
                attn_weights_slice[:, :, (-(q_num_to_pad // stride)):, :] = 0

            # Block sum
            attn_sum = (
                attn_weights_slice.view(
                    batch_size, num_kv_head, num_blocks_per_chunk, reshaped_block_size, -1, reshaped_block_size
                )
                .sum(dim=-1)
                .sum(dim=-2)
                .to("cuda")
            )

        # Select blocks based on the threshold
        simple_mask = find_blocks_chunked(
            attn_sum,
            k_block_num - q_block_num + chunk_idx * num_blocks_per_chunk,
            threshold,
            None,
            decoding=False,
            mode="prefill",
            causal=causal,
        )

        attn_sum_list.append(attn_sum)
        simple_mask_list.append(simple_mask)

        del attn_weights_slice

    if not use_triton:
        del reshaped_query, reshaped_key

    # Concatenate results from all chunks
    attn_sums = torch.cat(attn_sum_list, dim=-2)
    simple_masks = torch.cat(simple_mask_list, dim=-2)

    # Apply a causal mask to the final output
    if causal:
        simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
            torch.tril(torch.ones(q_block_num, q_block_num, dtype=torch.bool, device=key_states.device), diagonal=0),
            simple_masks[:, :, -q_block_num:, -q_block_num:],
            False,
        )

    # Always keep the sink block
    if keep_sink:
        simple_masks[:, :, :, 0] = True

    # Always keep diagonal (recent) blocks
    if keep_recent:
        eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=torch.bool)
        eye_matrix_expanded = eye_matrix.unsqueeze(0).unsqueeze(0).expand(1, num_kv_head, q_block_num, q_block_num)
        simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
            eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:]
        )

    return attn_sums, simple_masks
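The block-sum step of the PyTorch fallback can be checked in isolation on CPU; since each softmax row sums to one, every block row of the pooled scores must sum to `reshaped_block_size` (toy shapes assumed below):

```python
import torch
import torch.nn.functional as F

# Pool softmax probabilities into rbs x rbs tiles: one score per
# (q_block, k_block) pair, mirroring the fallback's view().sum().sum().
rbs = 2  # reshaped_block_size = block_size // stride
probs = F.softmax(torch.randn(1, 1, 8, 8), dim=-1)  # [batch, head, q, k]
attn_sum = probs.view(1, 1, 8 // rbs, rbs, -1, rbs).sum(dim=-1).sum(dim=-2)
print(attn_sum.shape)  # torch.Size([1, 1, 4, 4])
```

Each of the four block rows sums to exactly `rbs` (= 2.0), which is a cheap invariant to assert when debugging the estimation pipeline.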


def compute_sparsity(mask: torch.Tensor, causal: bool = True) -> float:
    """
    Compute the sparsity ratio of a block mask.

    Args:
        mask: Boolean mask [batch, heads, q_blocks, k_blocks]
        causal: Whether the mask is causal (only the lower triangle counts)

    Returns:
        Sparsity ratio (0.0 = dense, 1.0 = fully sparse)
    """
    batch, heads, q_blocks, k_blocks = mask.shape

    if causal:
        # Only count the lower triangle
        causal_mask = torch.tril(torch.ones(q_blocks, k_blocks, device=mask.device, dtype=torch.bool))
        total_blocks = causal_mask.sum().item() * batch * heads
        selected_blocks = (mask & causal_mask.unsqueeze(0).unsqueeze(0)).sum().item()
    else:
        total_blocks = mask.numel()
        selected_blocks = mask.sum().item()

    return 1.0 - (selected_blocks / total_blocks)
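For example, a causal 4×4 block grid that keeps only the sink column and the diagonal (the same arithmetic as the causal branch above, inlined on CPU):

```python
import torch

q_blocks = k_blocks = 4
mask = torch.zeros(1, 1, q_blocks, k_blocks, dtype=torch.bool)
mask[..., 0] = True                                   # sink block column
mask[..., torch.arange(4), torch.arange(4)] = True    # diagonal blocks
causal = torch.tril(torch.ones(q_blocks, k_blocks, dtype=torch.bool))
total = causal.sum().item()              # 10 causal block positions
selected = (mask & causal).sum().item()  # 7 kept (sink col + diagonal, overlap at [0,0])
print(round(1.0 - selected / total, 3))  # 0.3
```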
@@ -1,114 +0,0 @@
# SparsePolicy Refactoring Test Report

## Task Overview

Per the requirements in task_plan.md, refactor nanovllm's SparsePolicy architecture (v4), migrating the chunked prefill attention computation logic entirely out of attention.py and into SparsePolicy.

## Scope

FullPolicy only; QuestPolicy and XAttentionBSAPolicy are untouched, and decode-phase logic is not modified.

## Completed Changes

### 1. policy.py (SparsePolicy base class)

- Added TYPE_CHECKING imports: `OffloadEngine`, `KVCacheManager`, `Sequence`
- Changed the `select_blocks` signature: added an `offload_engine` parameter
- Added the `compute_chunked_attention` abstract method with parameters:
  - `q, k, v`: tensors
  - `layer_id`: layer index
  - `softmax_scale`: softmax scaling factor
  - `offload_engine`: OffloadEngine instance
  - `kvcache_manager`: KVCacheManager instance
  - `current_chunk_idx`: current chunk index
  - `seq`: Sequence object
  - `num_tokens`: token count of the current chunk

### 2. full_policy.py (FullAttentionPolicy)

- Updated TYPE_CHECKING imports
- Added the `offload_engine` parameter to the `select_blocks` signature
- Renamed `compute_prefill_attention` → `compute_chunked_attention`
- Added a `kvcache_manager` parameter, replacing all `seq.kvcache_manager` references
- Added debug logging

### 3. attention.py

- Simplified the `_chunked_prefill_attention` method:
  - Removed all `flash_attn_*` calls
  - Removed all `merge_attention_outputs` calls
  - Kept only the delegating call to `sparse_policy.compute_chunked_attention()`
- Deleted redundant methods: `_sync_load_previous_chunks`, `_ring_buffer_pipeline_load`
- Added the `offload_engine` parameter to the `select_blocks` call on the decode path

## Acceptance Criteria

| Criterion | Status | Notes |
|-----------|--------|-------|
| test_needle.py --enable-offload passes | ✅ | Test output: PASSED |
| No flash_attn_* calls on the attention.py chunked prefill path | ✅ | No direct flash_attn calls inside `_chunked_prefill_attention` (lines 169-230) |
| No merge_attention_outputs calls on the attention.py chunked prefill path | ✅ | Same as above |
| All KV communication goes through offload_engine methods | ✅ | All via `offload_engine.load_to_slot_layer`, `get_kv_for_slot`, `get_prefill_buffer_slice` |

## Test Results

```
============================================================
Needle-in-Haystack Test
============================================================
Model: /home/zijie/models/Llama-3.1-8B-Instruct
Max model len: 131072
Input length: 8192
Block size: 1024
Needle position: 50%
Needle value: 7492
CPU offload: True
Sparse policy: FULL
============================================================

[NeedleTest] Target: 8192, Actual: 8213 tokens (diff=21)
Expected: 7492
Output: 7492<|eot_id|>...
Status: PASSED
============================================================

test_needle: PASSED
```

## Performance Metrics

- Prefill: 3527 tok/s
- Decode: 11 tok/s
- TTFT: 2329.29 ms
- TPOT: 655.38 ms

## Architecture Change Summary

**Before**:
```
attention.py::_chunked_prefill_attention()
  ├── Fetch cpu_block_table
  ├── Call sparse_policy.select_blocks()
  ├── Call flash_attn_with_lse + merge_attention_outputs directly
  └── Return result
```

**After**:
```
attention.py::_chunked_prefill_attention()
  ├── Gather context information
  ├── Call sparse_policy.compute_chunked_attention()  # delegate all computation
  └── Return result

sparse_policy.compute_chunked_attention()  # in FullPolicy
  ├── Fetch cpu_block_table
  ├── Call self.select_blocks()
  ├── Load historical KV and compute attention over it
  ├── Compute current-chunk attention (causal)
  ├── Merge all partial results
  └── Return final output
```
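The delegation pattern above can be sketched as follows; the class and method bodies here are illustrative stubs built around the parameter names listed earlier (the real implementations live in nanovllm's `policy.py` and `attention.py`):

```python
from abc import ABC, abstractmethod

class SparsePolicy(ABC):
    @abstractmethod
    def compute_chunked_attention(self, q, k, v, layer_id, softmax_scale,
                                  offload_engine, kvcache_manager,
                                  current_chunk_idx, seq, num_tokens):
        ...

class FullPolicy(SparsePolicy):
    def compute_chunked_attention(self, q, k, v, layer_id, softmax_scale,
                                  offload_engine, kvcache_manager,
                                  current_chunk_idx, seq, num_tokens):
        # The real code loads historical KV via offload_engine, runs flash
        # attention per chunk, and merges partial outputs; stubbed here.
        return f"attention(layer={layer_id}, chunk={current_chunk_idx})"

def chunked_prefill_attention(policy, q, k, v, **ctx):
    # attention.py after the refactor: a thin delegation, with no
    # flash_attn_* or merge_attention_outputs calls left at this layer.
    return policy.compute_chunked_attention(q, k, v, **ctx)

out = chunked_prefill_attention(
    FullPolicy(), None, None, None,
    layer_id=0, softmax_scale=1.0, offload_engine=None,
    kvcache_manager=None, current_chunk_idx=3, seq=None, num_tokens=1024,
)
print(out)  # attention(layer=0, chunk=3)
```

Keeping the attention module policy-agnostic this way is what lets QuestPolicy or XAttentionBSAPolicy later override the same hook without touching attention.py.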

## Conclusion

The SparsePolicy architecture v4 refactoring is complete: all acceptance criteria are satisfied and the test passes.