feat: add xattn kernels test and update testing rules
- Add test_xattn_kernels.py demonstrating flat_group_gemm_fuse_reshape and softmax_fuse_block_sum Triton kernels with structured data
- Update testing.md with new test code style guidelines
- Update xattn.py and xattn_bsa.py with improvements

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
@@ -1,98 +1,108 @@

# Testing

## Test Code Style

All test code follows the style below.

### Naming Convention

- All test files must be named `test_*.py`
- Example: `test_offload_engine.py`, `test_ring_buffer.py`

### Purpose

Tests are **educational scripts** for understanding module behavior, NOT traditional unit tests:

- Focus on demonstrating how modules work
- Show the flow and interaction between components
- Help developers understand implementation details

### Code Style

1. **Script-based structure**: Write tests as executable scripts, not pytest-style functions
2. **Utility functions**: Extract reusable steps as helper functions at the top of the file
3. **Main flow as script**: The actual test/demonstration logic runs as top-level script code

### File Structure

```python
"""
Test: [module name]

[Brief description of what is tested and the data flow]
"""
import sys

sys.path.insert(0, "/home/zijie/Code/nano-vllm")

from nanovllm.xxx import xxx


# ============================================================
# Parameters
# ============================================================

param1 = value1  # comment explaining the constraint
param2 = value2

# ============================================================
# Build inputs
# ============================================================

input_tensor = ...  # use structured data for easy verification

# ============================================================
# Step N: [operation name]
# ============================================================

output = some_function(input_tensor, ...)

# Verify: [explanation of the verification logic]
expected = ...
actual = output[...].item()
assert actual == expected, f"xxx: {actual} != {expected}"

print("test_xxx: PASSED")
```

### Core Principles

| Principle | Description |
|------|------|
| **Minimize print** | Only print `PASSED` at the end; do not print intermediate results |
| **Structured data** | Use predictable inputs (all ones, alternating even/odd, etc.) so results can be verified by hand |
| **Comment the verification logic** | Explain how the expected value is computed in a comment before each assert |
| **Separate sections with `====`** | Use `# ============` headers to separate parameters, inputs, and each step |
| **Verify with assert** | Compare results with assert, not print |

Keep comments concise, and only add them where the code is not self-explanatory.
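As a minimal illustration of the structured-data principle (plain Python, hypothetical values, not one of the repository's tests), predictable inputs let the expected value be computed by hand:

```python
# Structured input: even positions = 1, odd positions = 2
seq = [1 if i % 2 == 0 else 2 for i in range(8)]

# Each adjacent (even, odd) pair contributes 1*2 + 2*1 = 4,
# and there are len(seq)//2 = 4 such pairs
pair_sum = sum(seq[i] * seq[i + 1] + seq[i + 1] * seq[i]
               for i in range(0, len(seq), 2))

# Verify: 4 pairs * 4 per pair = 16, computed by hand before running
expected = 4 * (len(seq) // 2)
assert pair_sum == expected, f"pair_sum: {pair_sum} != {expected}"
print("example: PASSED")
```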
### Output

- **Minimize print statements** - the code should be self-explanatory
- Only print a final "PASSED" message at the end
- Use `assert` for verification instead of printing results
- If the user needs explanation, they will ask

```python
# ✅ Correct
assert actual == expected, f"xxx: {actual} != {expected}"
print("test_xxx: PASSED")

# ❌ Wrong
print(f"Output: {output}")
print(f"Expected: {expected}, actual: {actual}")
```

### Parameter Comments

```python
# ✅ Correct: the comment states the constraint
seq_len = 512       # Triton requires seq_len >= stride * BLOCK_M
segment_size = 128  # must be >= block_size

# ❌ Wrong: meaningless comment
seq_len = 512  # sequence length
```

### Verification Comments

```python
# ✅ Correct: explain the computation
# Verify: anti-diagonal sum
# Q[odd]*K[even] + Q[even]*K[odd] = 2*1 + 1*2 = 4, with stride/2 such pairs
expected = (2*1 + 1*2) * (stride // 2) * head_dim

# ❌ Wrong: a bare formula with no explanation
expected = 4 * 2 * 128
```

## Running Tests

```bash
# Run a single test
PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_xxx.py

# Run on a specific GPU
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_xxx.py
```

## Benchmarks

```bash
python bench.py          # GPU benchmark
python bench_offload.py  # CPU offload benchmark
python bench_vllm.py     # vLLM comparison
```

## Quick Verification

```bash
# Import test
python -c "from nanovllm import LLM"

# Run offload benchmark (tests CPU-primary ring buffer mode)
python bench_offload.py
```

@@ -2,69 +2,334 @@
XAttention Block Sparse Attention (BSA) Policy for nano-vllm.

This module implements XAttention-inspired block sparse attention for chunked prefill.
Current implementation loads all historical blocks (FULL strategy);
sparse selection is to be implemented in the next phase.

Key design:
1. Use xattn_estimate_chunked to estimate the sparse block mask
2. Use the BSA kernel for efficient sparse attention computation
3. Support chunked prefill with q_start_pos for correct position handling

Note: Decode phase is not supported - use FullAttentionPolicy for decode.
"""

import logging
import torch
from typing import List, Tuple, TYPE_CHECKING

from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
from nanovllm.utils.context import get_context

if TYPE_CHECKING:
    from nanovllm.kvcache.offload_engine import OffloadEngine
    from nanovllm.kvcache.manager import KVCacheManager
    from nanovllm.engine.sequence import Sequence

logger = logging.getLogger(__name__)

# Check BSA availability
try:
    from block_sparse_attn import block_sparse_attn_func
    BSA_AVAILABLE = True
except ImportError:
    BSA_AVAILABLE = False
    logger.warning("block_sparse_attn not available, XAttentionBSAPolicy will fall back to dense")

# Check xattn_estimate_chunked availability
try:
    from nanovllm.ops.xattn import xattn_estimate_chunked
    XATTN_AVAILABLE = True
except ImportError:
    XATTN_AVAILABLE = False
    logger.warning("xattn_estimate_chunked not available")

def expand_kv_for_gqa(
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    num_heads: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Expand KV for Grouped Query Attention.

    Args:
        key_states: [B, num_kv_heads, seq_len, head_dim]
        value_states: [B, num_kv_heads, seq_len, head_dim]
        num_heads: Number of query heads

    Returns:
        Expanded (key, value) with shape [B, num_heads, seq_len, head_dim]
    """
    num_kv_heads = key_states.shape[1]
    if num_heads == num_kv_heads:
        return key_states, value_states
    num_groups = num_heads // num_kv_heads
    return (
        key_states.repeat_interleave(num_groups, dim=1),
        value_states.repeat_interleave(num_groups, dim=1),
    )


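The expansion above can be checked with a small numpy sketch (illustrative only; `numpy.repeat` along the head axis behaves like `torch.repeat_interleave(dim=1)`, and the shapes here are made-up toy values):

```python
import numpy as np

B, num_kv_heads, seq_len, head_dim = 1, 2, 4, 3
num_heads = 8  # 4 query heads share each KV head
num_groups = num_heads // num_kv_heads

k = np.arange(B * num_kv_heads * seq_len * head_dim, dtype=np.float32)
k = k.reshape(B, num_kv_heads, seq_len, head_dim)

# Repeat each KV head num_groups times along the head axis
k_expanded = np.repeat(k, num_groups, axis=1)

assert k_expanded.shape == (B, num_heads, seq_len, head_dim)
# Query heads 0..3 all see KV head 0, heads 4..7 see KV head 1
assert (k_expanded[:, 0] == k_expanded[:, 3]).all()
assert (k_expanded[:, 0] == k[:, 0]).all()
assert (k_expanded[:, 4] == k[:, 1]).all()
```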
class XAttentionBSAPolicy(SparsePolicy):
    """
    XAttention Block Sparse Attention policy for chunked prefill.

    This policy uses block-level estimation to determine which KV blocks
    are important for the current chunk's queries, enabling sparse computation.
    Uses xattn_estimate_chunked to estimate the sparse mask, then the BSA
    kernel for efficient sparse attention computation.

    Note: Current implementation loads all historical chunks (FULL strategy).
    Sparse selection is to be implemented in the next phase.

    Note:
        - Only supports prefill phase (decode uses FullAttentionPolicy)
        - BSA block size is fixed at 128 tokens
    """

    supports_prefill = True
    supports_decode = False  # Decode uses FullAttentionPolicy
    requires_block_selection = False  # Selection happens internally

    # BSA requires 128-token blocks
    BSA_BLOCK_SIZE = 128

    def __init__(
        self,
        threshold: float = 0.9,
        stride: int = 8,
        chunk_size: int = 16384,
        block_size: int = 128,
        samples_per_chunk: int = 128,
        use_triton: bool = True,
    ):
        """
        Initialize XAttention BSA policy.

        Args:
            threshold: Cumulative attention threshold for block selection (0-1).
                Higher values = more blocks selected = less sparse
            stride: Stride for Q/K reshape in estimation (typically 8)
            chunk_size: Processing chunk size for xattn_estimate (Triton alignment)
            block_size: BSA block size (must be 128)
            samples_per_chunk: Samples per chunk for estimation (unused)
            use_triton: Whether to use Triton kernels
        """
        self.block_size = block_size
        self.samples_per_chunk = samples_per_chunk
        self.threshold = threshold
        self.stride = stride
        self.chunk_size = chunk_size
        self.use_triton = use_triton
        self._num_heads = None  # Set during first forward

    def select_blocks(
        self,
        available_blocks: List[int],
        offload_engine: "OffloadEngine",
        ctx: PolicyContext,
    ) -> List[int]:
        """
        Select blocks to load from CPU.

        Returns all blocks (FULL strategy for now) - the actual selection
        happens in compute_chunked_prefill. Sparse selection is to be
        implemented in the next phase.
        """
        return available_blocks

    def _load_all_historical_kv(
        self,
        cpu_block_table: List[int],
        layer_id: int,
        offload_engine: "OffloadEngine",
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Load all historical K/V from CPU to GPU.

        Args:
            cpu_block_table: List of CPU block IDs
            layer_id: Current layer index
            offload_engine: OffloadEngine instance

        Returns:
            (k_hist, v_hist) with shape [total_tokens, kv_heads, head_dim]
        """
        if not cpu_block_table:
            return None, None

        k_list = []
        v_list = []

        for cpu_block_id in cpu_block_table:
            k_block, v_block = offload_engine.load_block_full_from_cpu(
                cpu_block_id, layer_id
            )
            k_list.append(k_block)
            v_list.append(v_block)

        # Concatenate: [num_blocks, block_size, kv_heads, head_dim] -> [total_tokens, kv_heads, head_dim]
        k_hist = torch.cat(k_list, dim=0)
        v_hist = torch.cat(v_list, dim=0)

        return k_hist, v_hist

    def compute_chunked_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
        offload_engine: "OffloadEngine",
        kvcache_manager: "KVCacheManager",
        current_chunk_idx: int,
        seq: "Sequence",
        num_tokens: int,
    ) -> torch.Tensor:
        """
        Compute attention for chunked prefill.

        NOTE: The current XAttention + BSA implementation has memory issues
        (loads all historical K/V at once, losing the benefit of sparse attention).
        Until a proper ring-buffer-based sparse implementation is ready,
        we fall back to the dense attention pipeline, which is memory-efficient.

        TODO: Implement proper sparse attention with the ring buffer pipeline:
        1. Use xattn_estimate_chunked to identify important blocks
        2. Only load selected blocks using the ring buffer
        3. Compute sparse attention on selected blocks only

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim] (current chunk)
            v: Value tensor [seq_len, num_kv_heads, head_dim] (current chunk)
            layer_id: Current layer index
            softmax_scale: Softmax scaling factor
            offload_engine: OffloadEngine for loading blocks
            kvcache_manager: KVCacheManager for block management
            current_chunk_idx: Current chunk index
            seq: Sequence object
            num_tokens: Number of tokens in current chunk

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        # Use the dense fallback, which is memory-efficient (ring buffer pipeline).
        # This is temporary until the proper sparse implementation is ready.
        return self._compute_dense_fallback(
            q, k, v, layer_id, softmax_scale, offload_engine,
            kvcache_manager, current_chunk_idx, seq, num_tokens
        )

    def _compute_dense_fallback(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
        offload_engine: "OffloadEngine",
        kvcache_manager: "KVCacheManager",
        current_chunk_idx: int,
        seq: "Sequence",
        num_tokens: int,
    ) -> torch.Tensor:
        """
        Fall back to dense attention when BSA/XAttn is not available.
        Uses FullAttentionPolicy's proven pipeline.
        """
        from nanovllm.ops.chunked_attention import flash_attn_with_lse, merge_attention_outputs

        logger.debug(f"[XAttn] FALLBACK to dense: layer={layer_id}, chunk={current_chunk_idx}")

        q_batched = q.unsqueeze(0)  # [1, seq_len, num_heads, head_dim]
        o_acc = None
        lse_acc = None
        compute_stream = offload_engine.compute_stream

        # Get historical CPU blocks
        cpu_block_table = kvcache_manager.get_prefilled_cpu_blocks(seq)

        # Process historical blocks using the pipeline
        if cpu_block_table:
            load_slots = list(range(offload_engine.num_ring_slots))
            num_blocks = len(cpu_block_table)

            if len(load_slots) == 1:
                # Single slot: load and compute serially
                slot = load_slots[0]
                for block_idx in range(num_blocks):
                    cpu_block_id = cpu_block_table[block_idx]
                    offload_engine.load_to_slot_layer(slot, layer_id, cpu_block_id)
                    offload_engine.wait_slot_layer(slot)

                    with torch.cuda.stream(compute_stream):
                        prev_k, prev_v = offload_engine.get_kv_for_slot(slot)
                        prev_o, prev_lse = flash_attn_with_lse(
                            q_batched, prev_k, prev_v,
                            softmax_scale=softmax_scale,
                            causal=False,
                        )
                        if o_acc is None:
                            o_acc, lse_acc = prev_o, prev_lse
                        else:
                            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)
                        offload_engine.record_slot_compute_done(slot)
            else:
                # Multiple slots: overlap loads with compute
                num_slots = len(load_slots)
                num_preload = min(num_slots, num_blocks)
                for i in range(num_preload):
                    offload_engine.load_to_slot_layer(load_slots[i], layer_id, cpu_block_table[i])

                for block_idx in range(num_blocks):
                    current_slot = load_slots[block_idx % num_slots]
                    offload_engine.wait_slot_layer(current_slot)

                    with torch.cuda.stream(compute_stream):
                        prev_k, prev_v = offload_engine.get_kv_for_slot(current_slot)
                        prev_o, prev_lse = flash_attn_with_lse(
                            q_batched, prev_k, prev_v,
                            softmax_scale=softmax_scale,
                            causal=False,
                        )
                        offload_engine.record_slot_compute_done(current_slot)

                        if o_acc is None:
                            o_acc, lse_acc = prev_o, prev_lse
                        else:
                            o_acc, lse_acc = merge_attention_outputs(o_acc, lse_acc, prev_o, prev_lse)

                    # Prefetch the next block into the slot just freed
                    next_block_idx = block_idx + num_slots
                    if next_block_idx < num_blocks:
                        next_slot = load_slots[next_block_idx % num_slots]
                        next_cpu_block_id = cpu_block_table[next_block_idx]
                        offload_engine.load_to_slot_layer(next_slot, layer_id, next_cpu_block_id)

        # Compute attention to the current chunk (causal)
        with torch.cuda.stream(compute_stream):
            k_curr, v_curr = offload_engine.get_prefill_buffer_slice(layer_id, num_tokens)
            current_o, current_lse = flash_attn_with_lse(
                q_batched, k_curr, v_curr,
                softmax_scale=softmax_scale,
                causal=True,
            )

        # Merge historical and current
        with torch.cuda.stream(compute_stream):
            if o_acc is None:
                final_o = current_o
            else:
                final_o, _ = merge_attention_outputs(o_acc, lse_acc, current_o, current_lse)

        torch.cuda.default_stream().wait_stream(compute_stream)
        return final_o.squeeze(0)

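The merge_attention_outputs steps above combine partial attention results using their log-sum-exp (LSE) statistics. A numpy sketch of that merge rule (standard online-softmax algebra on toy shapes, not the library's actual implementation) shows why attending to the KV in pieces matches full attention:

```python
import numpy as np

def attn_with_lse(q, k, v, scale):
    """Softmax attention over k/v, also returning the log-sum-exp per query."""
    s = q @ k.T * scale                   # [nq, nk] scores
    m = s.max(axis=-1, keepdims=True)
    p = np.exp(s - m)
    l = p.sum(axis=-1, keepdims=True)
    o = (p / l) @ v                       # [nq, d]
    lse = (m + np.log(l)).squeeze(-1)     # [nq]
    return o, lse

def merge(o1, lse1, o2, lse2):
    """Merge two partial attention outputs via their LSE weights."""
    m = np.maximum(lse1, lse2)
    w1 = np.exp(lse1 - m)
    w2 = np.exp(lse2 - m)
    o = (o1 * w1[:, None] + o2 * w2[:, None]) / (w1 + w2)[:, None]
    lse = m + np.log(w1 + w2)
    return o, lse

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4))
k = rng.standard_normal((6, 4))
v = rng.standard_normal((6, 4))
scale = 1 / np.sqrt(4)

# Full attention vs. merging two halves of the KV
o_full, _ = attn_with_lse(q, k, v, scale)
o1, lse1 = attn_with_lse(q, k[:3], v[:3], scale)
o2, lse2 = attn_with_lse(q, k[3:], v[3:], scale)
o_merged, _ = merge(o1, lse1, o2, lse2)

assert np.allclose(o_full, o_merged)
```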
    def compute_chunked_decode(
        self,
        q: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
        offload_engine: "OffloadEngine",
        kvcache_manager: "KVCacheManager",
        seq: "Sequence",
    ) -> torch.Tensor:
        """
        XAttention does not support the decode phase.
        """
        raise NotImplementedError(
            "XAttentionBSAPolicy does not support decode phase. "
            "Use FullAttentionPolicy for decode."
        )

    def reset(self) -> None:
        """Reset policy state."""
        pass

    def __repr__(self) -> str:
        return f"XAttentionBSAPolicy(threshold={self.threshold}, stride={self.stride})"

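The multi-slot branch of _compute_dense_fallback follows a classic prefetch pipeline: preload every slot, then wait, compute, and prefetch the block that will reuse the freed slot. A minimal CPU-only sketch of that scheduling order (with hypothetical `load`/`compute` callables standing in for the engine's async methods):

```python
def pipeline(num_blocks, num_slots, load, compute):
    """Rotate blocks through num_slots buffers: preload, then compute + prefetch."""
    # Preload as many blocks as there are slots
    for i in range(min(num_slots, num_blocks)):
        load(slot=i % num_slots, block=i)

    for block in range(num_blocks):
        slot = block % num_slots
        compute(slot=slot, block=block)
        # Prefetch the block that will reuse this slot next
        nxt = block + num_slots
        if nxt < num_blocks:
            load(slot=nxt % num_slots, block=nxt)

events = []
pipeline(
    5, 2,
    load=lambda slot, block: events.append(("load", slot, block)),
    compute=lambda slot, block: events.append(("compute", slot, block)),
)
# Every load for a slot precedes the compute that consumes it
assert events[0] == ("load", 0, 0) and events[1] == ("load", 1, 1)
assert events.index(("load", 0, 2)) > events.index(("compute", 0, 0))
```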
@@ -419,7 +419,9 @@ def flat_group_gemm_fuse_reshape(
    assert key_states.shape[1] == num_heads
    assert key_states.shape[3] == head_dim

    # Use zeros instead of empty to handle causal early-exit in the kernel
    # (some blocks may not be written due to the causal mask optimization)
    output = torch.zeros(
        (batch_size, num_heads, q_len // stride, kv_len // stride),
        dtype=query_states.dtype,
        device=query_states.device
@@ -1067,6 +1069,7 @@ def xattn_estimate_chunked(
        )

        # Softmax + block sum
        # segment_size should match the standard xattn_estimate for consistency
        attn_sum = softmax_fuse_block_sum(
            attn_weights,
            reshaped_block_size,
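softmax_fuse_block_sum takes a scale with log2(e) folded in so the kernel can use the cheaper exp2. A small numpy check (illustrative, independent of the Triton code) that an exp2 softmax with log2(e) folded into the scale equals a standard softmax:

```python
import numpy as np

LOG2_E = 1.4426950408889634  # log2(e)

rng = np.random.default_rng(0)
scores = rng.standard_normal((4, 8))
scale = 0.125  # any base scale, e.g. 1/sqrt(head_dim)

# Standard softmax with exp
s = scores * scale
ref = np.exp(s - s.max(-1, keepdims=True))
ref = ref / ref.sum(-1, keepdims=True)

# exp2-based softmax with log2(e) folded into the scale:
# exp2(x * log2(e)) == exp(x), and softmax is shift-invariant
s2 = scores * (LOG2_E * scale)
out = np.exp2(s2 - s2.max(-1, keepdims=True))
out = out / out.sum(-1, keepdims=True)

assert np.allclose(ref, out)
```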
@@ -1082,6 +1085,14 @@ def xattn_estimate_chunked(
        attn_sum = attn_sum[:, :, :q_block_num, :k_block_num]
    else:
        # PyTorch fallback implementation.
        # Match the Triton kernel exactly for consistency.
        #
        # Triton uses:
        #   1. exp2 (base-2 exponential) for softmax
        #   2. a scale factor that includes log2(e) = 1.4426950408889634
        #   3. causal mask: q_pos >= k_pos (not q_pos + 1 > k_pos)
        #   4. chunk_start for global Q position tracking

        # Reshape K: interleave positions and concatenate head dims
        reshaped_key = torch.cat(
            [(key_states[:, :, k::stride, :]) for k in range(stride)], dim=-1
@@ -1093,49 +1104,58 @@ def xattn_estimate_chunked(
            dim=-1,
        )

        # Use the same scale as Triton: includes log2(e) for exp2 compatibility
        scale = 1.4426950408889634 / math.sqrt(head_dim) / stride / norm

        # Convert to float32 for numerical stability (matching Triton)
        reshaped_query_f32 = reshaped_query.to(torch.float32)
        reshaped_key_f32 = reshaped_key.to(torch.float32)

        # Compute attention weights: (B, H, q_len/stride, k_len/stride)
        attn_weights = torch.matmul(
            reshaped_query_f32, reshaped_key_f32.transpose(2, 3)
        ) * scale

        # Apply causal mask (matching Triton's logic exactly)
        if causal:
            # Triton uses: offs_q = chunk_start + block_id * block_size + arange(0, block_size)
            # chunk_start = q_start_block * reshaped_block_size
            chunk_start = q_start_block * reshaped_block_size

            # Create position indices in reshaped space
            q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
            k_positions = torch.arange(reshaped_k_len, device=attn_weights.device)

            # Triton causal mask: q_pos >= k_pos
            causal_mask = q_positions[:, None] >= k_positions[None, :]  # (reshaped_q_len, reshaped_k_len)

            # Apply causal mask: set future positions to -1e6 (matching Triton)
            attn_weights = attn_weights.masked_fill(
                ~causal_mask.unsqueeze(0).unsqueeze(0), -1e6
            )

        # Softmax using exp2 (matching Triton exactly)
        # Triton: X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
        # All computation in float32
        attn_max = attn_weights.max(dim=-1, keepdim=True).values
        attn_weights_shifted = attn_weights - attn_max
        attn_exp2 = torch.exp2(attn_weights_shifted)
        attn_sum_exp2 = attn_exp2.sum(dim=-1, keepdim=True)
        attn_weights = attn_exp2 / attn_sum_exp2

        # Mask for valid Q positions (matching Triton's sum_mask)
        # Triton: sum_mask = offs_q[:, None] < real_q_len
        # real_q_len = chunk_start + valid_q_reshaped
        chunk_start = q_start_block * reshaped_block_size
        real_q_len = chunk_start + valid_q_reshaped
        q_positions = torch.arange(reshaped_q_len, device=attn_weights.device) + chunk_start
        valid_q_mask = q_positions < real_q_len  # (reshaped_q_len,)

        # Zero out invalid Q positions
        attn_weights = attn_weights * valid_q_mask.view(1, 1, -1, 1).float()

        # Aggregate to block level (keep in float32)
        attn_sum = attn_weights.view(
            batch_size,
            num_heads,
@@ -1145,6 +1165,9 @@ def xattn_estimate_chunked(
            reshaped_block_size,
        ).sum(dim=-1).sum(dim=-2)

        # Convert back to the input dtype for consistency
        attn_sum = attn_sum.to(query_states.dtype)

    # Find blocks that exceed the threshold
    simple_mask = find_blocks_chunked(
        attn_sum,

tests/test_xattn_kernels.py (new file, 86 lines)
@@ -0,0 +1,86 @@
"""
Test: XAttention Triton kernels

Demonstrates the two core Triton kernels of XAttention:
1. flat_group_gemm_fuse_reshape: computes attention scores after the stride reshape (anti-diagonal sum)
2. softmax_fuse_block_sum: applies softmax to the attention scores, then sums per block

Data flow:
    Q, K [batch, heads, seq_len, head_dim]
        ↓ flat_group_gemm_fuse_reshape
    attn_scores [batch, heads, seq_len/stride, seq_len/stride]
        ↓ softmax_fuse_block_sum
    block_sums [batch, heads, q_blocks, k_blocks]
"""
import sys

import torch

sys.path.insert(0, "/home/zijie/Code/nano-vllm")
from nanovllm.ops.xattn import flat_group_gemm_fuse_reshape, softmax_fuse_block_sum


# ============================================================
# Parameters
# ============================================================

seq_len = 512       # Triton requires seq_len >= stride * BLOCK_M = 4 * 128 = 512
head_dim = 128
stride = 4
block_size = 128    # softmax block size (in reshaped space)
segment_size = 128  # Triton kernel requires segment_size >= block_size

# ============================================================
# Build inputs: even positions = 1, odd positions = 2
# ============================================================

Q = torch.zeros(1, 1, seq_len, head_dim, dtype=torch.bfloat16).cuda()
K = torch.zeros(1, 1, seq_len, head_dim, dtype=torch.bfloat16).cuda()
for i in range(seq_len):
    if i % 2 == 0:
        Q[0, 0, i, :] = 1
        K[0, 0, i, :] = 1
    else:
        Q[0, 0, i, :] = 2
        K[0, 0, i, :] = 2

# ============================================================
# Step 1: flat_group_gemm_fuse_reshape
# ============================================================

attn_scores = flat_group_gemm_fuse_reshape(
    Q, K, stride,
    chunk_start=0,
    chunk_end=seq_len // stride,
    is_causal=False
)

# Verify: anti-diagonal sum
# The anti-diagonal of each stride x stride block: Q[odd]*K[even] + Q[even]*K[odd] = 2*1 + 1*2 = 4
# The anti-diagonal has stride/2 such pairs, then multiply by head_dim
expected_gemm = (2*1 + 1*2) * (stride // 2) * head_dim
actual_gemm = attn_scores[0, 0, 0, 0].item()
assert actual_gemm == expected_gemm, f"flat_group_gemm: {actual_gemm} != {expected_gemm}"

# ============================================================
# Step 2: softmax_fuse_block_sum
# ============================================================

reshaped_len = seq_len // stride
scale = 1.4426950408889634  # log2(e) for exp2

block_sums = softmax_fuse_block_sum(
    attn_scores,
    block_size,
    segment_size,
    chunk_start=0,
    chunk_end=reshaped_len,
    real_q_len=reshaped_len,
    scale=scale,
    is_causal=False
)

# Verify: sum of softmax results per block
# All attn_scores are equal → softmax is uniform → block_sum = block_size^2 / reshaped_len
expected_sum = block_size * block_size / reshaped_len
actual_sum = block_sums[0, 0, 0, 0].item()
assert actual_sum == expected_sum, f"softmax_fuse_block_sum: {actual_sum} != {expected_sum}"

print("test_xattn_kernels: PASSED")
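The Step 2 expectation in the test can be reproduced without a GPU; a numpy sketch that mirrors the test's reasoning (uniform softmax, then a block sum) rather than the Triton kernel itself:

```python
import numpy as np

reshaped_len = 128
block_size = 128

# All scores equal → softmax over each row is uniform: 1 / reshaped_len per entry
attn = np.full((reshaped_len, reshaped_len), 7.0)
probs = np.exp(attn - attn.max(-1, keepdims=True))
probs = probs / probs.sum(-1, keepdims=True)

# Sum over one block of block_size x block_size entries:
# block_size rows, each contributing block_size / reshaped_len
block_sum = probs[:block_size, :block_size].sum()
expected = block_size * block_size / reshaped_len
assert np.isclose(block_sum, expected)
```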