feat: add xattn kernels test and update testing rules
- Add test_xattn_kernels.py demonstrating flat_group_gemm_fuse_reshape and softmax_fuse_block_sum Triton kernels with structured data - Update testing.md with new test code style guidelines - Update xattn.py and xattn_bsa.py with improvements Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,98 +1,108 @@
|
||||
# Testing
|
||||
|
||||
## Test File Guidelines
|
||||
## Test Code Style
|
||||
|
||||
### Naming Convention
|
||||
所有测试代码遵循以下风格:
|
||||
|
||||
- All test files must be named `test_*.py`
|
||||
- Example: `test_offload_engine.py`, `test_ring_buffer.py`
|
||||
|
||||
### Purpose
|
||||
|
||||
Tests are **educational scripts** for understanding module behavior, NOT traditional unit tests:
|
||||
- Focus on demonstrating how modules work
|
||||
- Show the flow and interaction between components
|
||||
- Help developers understand implementation details
|
||||
|
||||
### Code Style
|
||||
|
||||
1. **Script-based structure**: Write tests as executable scripts, not pytest-style functions
|
||||
2. **Utility functions**: Extract reusable steps as helper functions at the top of the file
|
||||
3. **Main flow as script**: The actual test/demonstration logic runs as top-level script code
|
||||
### 文件结构
|
||||
|
||||
```python
|
||||
# Example structure:
|
||||
"""
|
||||
Test: [模块名称]
|
||||
|
||||
[简要说明测试内容和数据流]
|
||||
"""
|
||||
import torch
|
||||
from nanovllm.kvcache import SomeModule
|
||||
import sys
|
||||
sys.path.insert(0, "/home/zijie/Code/nano-vllm")
|
||||
from nanovllm.xxx import xxx
|
||||
|
||||
# ============================================================
|
||||
# Utility Functions
|
||||
# 参数配置
|
||||
# ============================================================
|
||||
|
||||
def verify(tensor, expected, name):
|
||||
actual = tensor.mean().item()
|
||||
assert abs(actual - expected) < 0.01, f"{name}: {actual} != {expected}"
|
||||
param1 = value1 # 说明约束条件
|
||||
param2 = value2
|
||||
|
||||
# ============================================================
|
||||
# Main Test Script
|
||||
# 构造输入
|
||||
# ============================================================
|
||||
|
||||
# 1. Initialize
|
||||
module = SomeModule(param=value)
|
||||
input_tensor = ... # 使用结构化数据便于验证
|
||||
|
||||
# 2. Test feature X
|
||||
result = module.do_something()
|
||||
assert result == expected_value
|
||||
# ============================================================
|
||||
# Step N: [操作名称]
|
||||
# ============================================================
|
||||
|
||||
# 3. Test feature Y
|
||||
...
|
||||
output = some_function(input_tensor, ...)
|
||||
|
||||
# 验证: [验证逻辑说明]
|
||||
expected = ...
|
||||
actual = output[...].item()
|
||||
assert actual == expected, f"xxx: {actual} != {expected}"
|
||||
|
||||
print("test_xxx: PASSED")
|
||||
```
|
||||
|
||||
### Comments
|
||||
### 核心原则
|
||||
|
||||
- Keep comments concise and clear
|
||||
- Only add comments where the code isn't self-explanatory
|
||||
- Use section headers (`# === Section ===`) to organize logical blocks
|
||||
| 原则 | 说明 |
|
||||
|------|------|
|
||||
| **最小化 print** | 只在最后输出 `PASSED`,不打印中间结果 |
|
||||
| **结构化数据** | 使用可预测的输入(全 1、偶奇交替等)便于手算验证 |
|
||||
| **注释说明验证逻辑** | 在 assert 前用注释解释预期值的计算方式 |
|
||||
| **分段用 `====`** | 用 `# ============` 分隔参数、输入、各步骤 |
|
||||
| **assert 验证** | 用 assert 而不是 print 比较结果 |
|
||||
|
||||
### Output
|
||||
### 输出规范
|
||||
|
||||
- **Minimize print statements** - the code should be self-explanatory
|
||||
- Only print a final "PASSED" message at the end
|
||||
- Use `assert` for verification instead of printing results
|
||||
- If the user needs explanation, they will ask
|
||||
```python
|
||||
# ✅ 正确
|
||||
assert actual == expected, f"xxx: {actual} != {expected}"
|
||||
print("test_xxx: PASSED")
|
||||
|
||||
# ❌ 错误
|
||||
print(f"输出: {output}")
|
||||
print(f"预期: {expected}, 实际: {actual}")
|
||||
```
|
||||
|
||||
### 参数注释
|
||||
|
||||
```python
|
||||
# ✅ 正确: 注释说明约束条件
|
||||
seq_len = 512 # Triton 要求 seq_len >= stride * BLOCK_M
|
||||
segment_size = 128 # 必须 >= block_size
|
||||
|
||||
# ❌ 错误: 无意义的注释
|
||||
seq_len = 512 # 序列长度
|
||||
```
|
||||
|
||||
### 验证逻辑注释
|
||||
|
||||
```python
|
||||
# ✅ 正确: 解释计算过程
|
||||
# 验证: 反对角线求和
|
||||
# Q[奇]*K[偶] + Q[偶]*K[奇] = 2*1 + 1*2 = 4,共 stride/2 对
|
||||
expected = (2*1 + 1*2) * (stride // 2) * head_dim
|
||||
|
||||
# ❌ 错误: 只写公式不解释
|
||||
expected = 4 * 2 * 128
|
||||
```
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
# Run a specific test
|
||||
python tests/test_offload_engine.py
|
||||
# 运行单个测试
|
||||
PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_xxx.py
|
||||
|
||||
# Run with specific GPU
|
||||
CUDA_VISIBLE_DEVICES=0 python tests/test_ring_buffer.py
|
||||
# 指定 GPU
|
||||
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_xxx.py
|
||||
```
|
||||
|
||||
## Benchmarks
|
||||
|
||||
```bash
|
||||
# Standard GPU benchmark
|
||||
python bench.py
|
||||
|
||||
# CPU offload benchmark
|
||||
python bench_offload.py
|
||||
|
||||
# vLLM comparison benchmark
|
||||
python bench_vllm.py
|
||||
```
|
||||
|
||||
## Quick Verification
|
||||
|
||||
```bash
|
||||
# Import test
|
||||
python -c "from nanovllm import LLM"
|
||||
|
||||
# Run offload benchmark (tests CPU-primary ring buffer mode)
|
||||
python bench_offload.py
|
||||
python bench.py # GPU benchmark
|
||||
python bench_offload.py # CPU offload benchmark
|
||||
python bench_vllm.py # vLLM comparison
|
||||
```
|
||||
|
||||
Reference in New Issue
Block a user