4 Commits

Author SHA1 Message Date
Zijie Tian
b8c00399af chore: sync submodule URL with tzj/minference (use HTTPS) 2026-01-18 19:32:18 +08:00
Zijie Tian
13586e689b docs: add chunked prefill integration plan
Analyze the memory-layout differences between the two branches and
explain why the block-based design is essential for supporting
arbitrary-length inference.

Key findings:
- tzj/vs_offload's max_seq_len design makes GPU memory grow with sequence length
- tzj/minference's block-based design keeps GPU memory fixed (~1.6 GB)
- A 24GB RTX 3090 can then serve 4M+ token inference

Plan: port the chunked prefill mechanism from tzj/minference to tzj/vs_offload:
- Block-based GPU cache (no layer dimension)
- Per-layer prefill buffer (fully parallel offload)
- Cross-layer pipeline buffers (double-buffering)
- Chunked prefill flow and online LSE merging

Sparse policy: keep the architecture; implement only the FULL policy for now

Related files:
- docs/chunked_prefill_integration_plan.md (new)
2026-01-18 18:49:19 +08:00
Zijie Tian
e72725c12b test: add OffloadedTensor unified test suite
Add comprehensive test suite for OffloadedTensor implementation,
including basic functionality, chunked GEMM, and sync analysis.

Components:
- OffloadedTensor: Virtual GPU tensor with transparent CPU/GPU data movement
- OffloadManager: LRU cache management with performance stats
- ChunkedOffloadLinear: Chunked GEMM along seqlen dimension

Tests (10 total):
- Basic functionality, MLP integration, LRU eviction, correctness
- Memory analysis, 128K sequence, performance comparison, transformers layer
- Sync behavior analysis, profiler analysis

Key findings:
- 93.9% memory savings for 128K sequences (3156MB → 191MB)
- Constant memory footprint regardless of sequence length
- Only 8% performance overhead from chunked processing

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-18 10:41:40 +08:00
Zijie Tian
cfb188c34a docs: add chunked prefill analysis for ultra-long sequences
Add comprehensive analysis document covering:
- MLP activation memory bottlenecks with SwiGLU architecture
- Chunked MLP strategy (98% memory reduction)
- Chunked prefill for single layers (78% memory reduction)
- Streaming Chunked Prefill (the optimal approach): GPU memory becomes constant
- Memory formulas and implementation guidance
- Theoretical maximum: 4M tokens on 24GB GPU (128× improvement)

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-16 10:38:02 +08:00
23 changed files with 2817 additions and 3442 deletions


@@ -1,158 +0,0 @@
---
allowed-tools: Bash(CUDA_VISIBLE_DEVICES=*), Bash(PYTHONPATH=*), Bash(python*), Bash(git*), Bash(rm*), Bash(ls*), Bash(cat*), Bash(nvidia-smi*), Read, Edit, Write, Glob, Grep, TodoWrite, Task
argument-hint: --gpu <id> [--no-interrupt]
description: Execute task_plan.md refactoring with specified GPU, optionally without user interruption
---
# Execute Task Plan (exec-plan)
# Execute Task Plan (exec-plan)

Execute the code refactoring described in `task_plan.md`, making sure the plan's end goals are fully achieved.

## Arguments

Command format: `/exec-plan --gpu <id> [--no-interrupt]`

| Argument | Description | Example |
|------|------|------|
| `--gpu <id>` | **Required**. Specifies the GPU ID; only this GPU may be used for debugging | `--gpu 0`, `--gpu 2` |
| `--no-interrupt` | Optional. Never interrupt execution; resolve or skip problems automatically instead of asking the user | `--no-interrupt` |

## Current arguments

```
$ARGUMENTS
```

## Preparation

### 1. Parse arguments

From `$ARGUMENTS`, extract:
- `GPU_ID`: from `--gpu <id>` or `-g <id>`
- `NO_INTERRUPT`: whether the `--no-interrupt` or `-n` flag is present

### 2. Validate arguments

**Must verify**:
- GPU_ID is a valid number
- Run `nvidia-smi -i <GPU_ID>` to confirm the GPU exists

### 3. Read task_plan.md

Read `task_plan.md` in the project root and understand:
- The overall goal
- The phased plan (Phase 1, 2, 3...)
- The list of files to modify
- Risks and caveats
- The test plan

## Execution flow

### Step 1: Create an execution plan

Use the TodoWrite tool to create a detailed execution plan, including:
- All phases extracted from task_plan.md
- Subtasks for each phase
- Test and verification steps

### Step 2: Refactor phase by phase

For each phase in task_plan.md:
1. **Read the current code**: use Read/Grep to understand the existing implementation
2. **Apply the changes**: use Edit/Write to modify the code
3. **Verify the changes**: run the relevant tests

### Step 3: Run the test plan

Execute the test plan defined in task_plan.md to verify the refactoring succeeded.

## GPU restriction rules

**Strict restriction**: only the specified GPU may be used. Every GPU-related command must be prefixed with `CUDA_VISIBLE_DEVICES`:

```bash
# Correct
CUDA_VISIBLE_DEVICES=$GPU_ID PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python test.py

# Wrong - other GPUs are forbidden
python test.py                           # may use default GPU 0
CUDA_VISIBLE_DEVICES=0,1 python test.py  # uses multiple GPUs
```

## Interruption rules

### When `--no-interrupt` is in effect

In the following situations, **do not stop to ask the user**; instead:

| Situation | Handling |
|------|----------|
| Test failure | Record the cause, attempt an automatic fix, continue |
| Code conflict | Attempt a reasonable resolution, record it |
| Uncertain implementation detail | Pick the most reasonable option and continue |
| Execution error | Analyze the error, attempt a fix, record the issue |

**Automatic decision principles**:
1. Prioritize functional correctness
2. Follow the existing code style
3. Prefer simple, direct implementations
4. Record every automatic decision in `progress.md`

### When `--no-interrupt` is not specified

The following situations **may be escalated to the user**:
- Multiple implementation options need a choice
- Tests keep failing and cannot be fixed automatically
- A problem or contradiction is found in task_plan.md

## Execution records

### Progress file: progress.md

Keep `progress.md` updated in real time:

```markdown
## Execution progress

### Phase X: [name]
- Status: [in progress/done/failed]
- Started: [time]
- Finished: [time]
- Files modified: [list]
- Automatic decisions: [if any]
- Issues: [if any]
```

### Findings file: findings.md

Record important discoveries made during execution in `findings.md`.

## Example usage

```bash
# Use GPU 2, interruptions allowed
/exec-plan --gpu 2

# Use GPU 0, no interruptions
/exec-plan --gpu 0 --no-interrupt

# Short form
/exec-plan -g 1 -n
```

## Completion criteria

When execution finishes, ensure:
1. **All phases done**: every phase in task_plan.md has been implemented
2. **Tests pass**: the full test plan in task_plan.md passes
3. **Code quality**: changes follow the project's coding conventions
4. **Docs updated**: progress.md contains a complete execution record

## Hard constraints

1. **GPU isolation**: never use any device other than the specified GPU
2. **Follow the plan**: execute task_plan.md strictly; no out-of-plan changes
3. **Incremental changes**: verify after each phase, not all at once at the end
4. **Rollback readiness**: before major changes, consider a git commit as a save point

.gitmodules vendored (6 lines changed)

@@ -1,4 +1,4 @@
-[submodule "3rdparty/Block-Sparse-Attention"]
-	path = 3rdparty/Block-Sparse-Attention
-	url = git@github.com:Zijie-Tian/Block-Sparse-Attention.git
+[submodule "3rdparty/Block-SparseAttention"]
+	path = 3rdparty/Block-SparseAttention
+	url = https://github.com/Zijie-Tian/Block-SparseAttention.git
 	branch = tzj/minference


@@ -64,6 +64,7 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
 | [`docs/xattention_integration.md`](docs/xattention_integration.md) | XAttention integration guide: algorithm, implementation, design decisions, and testing |
 | [`docs/xattention_analysis.md`](docs/xattention_analysis.md) | XAttention algorithm analysis: chunked estimation, block sparse attention, integration design |
 | [`docs/development_notes.md`](docs/development_notes.md) | Development notes and scratchpad for ongoing work |
+| [`docs/chunked_prefill_analysis.md`](docs/chunked_prefill_analysis.md) | **NEW**: Chunked prefill for ultra-long sequences (1M+), memory analysis, MLP activation breakdown, implementation guide |
 ## Configuration

(File diff suppressed because it is too large)


@@ -0,0 +1,354 @@
# Chunked Prefill Integration Plan

**Goal**: port the chunked prefill mechanism from the tzj/minference branch to the tzj/vs_offload branch
**Created**: 2026-01-18
**Base branch**: `tzj/vs_offload`
**Source branch**: `tzj/minference`

---

## Goal

Implement chunked prefill + layerwise offload on the tzj/vs_offload branch, enabling arbitrary-length inference (4M, 8M, 16M+ tokens) on a 24GB RTX 3090.

---

## Core problem

### Limitations of tzj/vs_offload

The GPU ring buffer on tzj/vs_offload is currently allocated by `max_seq_len`, so GPU memory grows linearly with sequence length:

```python
# Current design
self.layer_k_cache = torch.zeros(
    num_kv_buffers,   # e.g., 4
    max_seq_len,      # e.g., 131072 tokens
    kv_heads,
    head_dim,
    dtype=dtype, device="cuda"
)
```

**Problem**:
- GPU memory required ≈ `max_seq_len × 4 × 8 × 128 × 2 bytes` per K buffer (and the same again for V)
- Infeasible for ultra-long sequences:
  - 4M tokens → ~64 GB GPU memory ❌
  - 8M tokens → ~128 GB GPU memory ❌
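These figures can be sanity-checked with a few lines of arithmetic, assuming the dimensions used above (4 KV buffers, 8 KV heads, head_dim 128, fp16). The helper below is illustrative, not part of the codebase:

```python
def ring_buffer_bytes(max_seq_len, num_kv_buffers=4, kv_heads=8,
                      head_dim=128, dtype_bytes=2):
    """GPU bytes for one K (or V) ring buffer sized by max_seq_len."""
    return num_kv_buffers * max_seq_len * kv_heads * head_dim * dtype_bytes

# K + V together, in GiB
for tokens in (128 * 1024, 4 * 1024 * 1024, 8 * 1024 * 1024):
    gib = 2 * ring_buffer_bytes(tokens) / 2**30
    print(f"{tokens:>10} tokens -> {gib:.1f} GiB")
```

This reproduces the ~2 GB at 128K, ~64 GB at 4M, and ~128 GB at 8M quoted above.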
### Solution: block-based design

The tzj/minference branch uses a block-based design, so the GPU working set is fixed:

```python
# Block-based design
self.k_cache_gpu = torch.zeros(
    num_gpu_blocks,   # e.g., 2
    block_size,       # e.g., 1024 tokens (fixed!)
    kv_heads,
    head_dim,
    dtype=dtype, device="cuda"
)
# GPU memory: ~4 MB (fixed; does not grow with sequence length)
```

**Advantages**:
- Total GPU memory is fixed (~1.6 GB including all buffers), independent of sequence length
- A 24GB RTX 3090 can run 4M+ tokens
- Ultra-long sequences are processed block by block via chunked prefill

---
## Memory layout comparison

| Component | tzj/vs_offload | tzj/minference | Notes |
|------|---------------|----------------|------|
| **GPU ring buffer** | `[num_kv_buffers, max_seq_len, ...]` | `[num_gpu_blocks, block_size, ...]` | minference has no layer dimension |
| **GPU memory** | ~2.15 GB (128K) → ~64 GB (4M) | ~4 MB (fixed) | minference saves dramatically |
| **Prefill buffer** | ❌ none | ✅ `[num_layers, block_size, ...]` | minference only |
| **Pipeline buffers** | ❌ none | ✅ double buffers `[blocks, block_size, ...]` | minference only |
| **CPU cache** | `[num_layers, num_cpu_blocks, block_size, ...]` | same | **identical** |

### Supported sequence lengths

| Sequence length | vs_offload GPU memory | minference GPU memory | RTX 3090 (24GB) |
|----------|-------------------|---------------------|-----------------|
| 128K tokens | ~2.15 GB | ~4 MB | ✅ both work |
| 1M tokens | ~16 GB | ~4 MB | ✅ both work |
| **4M tokens** | **~64 GB** ❌ | **~4 MB** ✅ | **minference only** |
| **8M tokens** | **~128 GB** ❌ | **~4 MB** ✅ | **minference only** |
| **16M+ tokens** | **~256 GB+** ❌ | **~4 MB** ✅ | **minference only** |

---

## Key design principles

1. **Block-based layout**: organized in `block_size` (1024-token) blocks, enabling chunked prefill
2. **Fixed GPU memory**: a constant factor that does not grow with sequence length
3. **CPU memory scales linearly**: `num_cpu_blocks = ceil(seq_len / block_size)`
4. **Unified ring buffer**: no layer dimension; all layers share the slots
5. **Fully parallel offload**: per-layer buffers maximize PCIe bandwidth
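Principles 2 and 3 can be made concrete with two small helpers (names are hypothetical, and absolute sizes depend on the model configuration, so no GB figures are asserted here):

```python
import math

def num_cpu_blocks(seq_len: int, block_size: int = 1024) -> int:
    """Principle 3: the number of pinned CPU blocks grows linearly with seq_len."""
    return math.ceil(seq_len / block_size)

def cpu_cache_bytes(seq_len: int, num_layers: int, kv_heads: int,
                    head_dim: int, block_size: int = 1024,
                    dtype_bytes: int = 2) -> int:
    """K+V CPU cache size: the only term that scales with sequence length.
    GPU-side buffers (principle 2) do not appear here at all."""
    blocks = num_cpu_blocks(seq_len, block_size)
    return 2 * num_layers * blocks * block_size * kv_heads * head_dim * dtype_bytes
```

Doubling `seq_len` (at block granularity) exactly doubles the CPU cache, while the GPU footprint stays constant.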
---
## Unified memory layout

### GPU memory layout

```python
class OffloadEngine:
    # (buffers allocated in __init__; sketch)

    # 1. Unified ring buffer - block-based, no layer dimension
    self.k_cache_gpu = torch.zeros(
        num_gpu_blocks,   # e.g., 2
        block_size,       # e.g., 1024
        kv_heads,
        head_dim,
        dtype=dtype, device="cuda"
    )  # ~4 MB (fixed)

    # 2. Per-layer prefill buffer - fully parallel offload
    self.prefill_k_buffer = torch.zeros(
        num_layers, block_size, kv_heads, head_dim,
        dtype=dtype, device="cuda"
    )  # ~58 MB (fixed)

    # 3. Cross-layer pipeline buffers - double-buffering
    self.layer_k_buffer_a = torch.zeros(
        max_prefill_blocks, block_size, kv_heads, head_dim,
        dtype=dtype, device="cuda"
    )  # ~512 MB (fixed)
    self.layer_k_buffer_b = torch.zeros(...)  # ~512 MB (fixed)

    # 4. Per-layer decode buffer
    self.decode_k_buffer = torch.zeros(
        num_layers, block_size, kv_heads, head_dim,
        dtype=dtype, device="cuda"
    )  # ~58 MB (fixed)

    # GPU total: ~1.6 GB (fixed; does not grow with sequence length)
```

### CPU memory layout

```python
# CPU cache - has a block dimension
self.k_cache_cpu = torch.zeros(
    num_layers,
    num_cpu_blocks,   # scales with sequence length
    block_size,
    kv_heads,
    head_dim,
    dtype=dtype, device="cpu", pin_memory=True
)
# 128K tokens: ~2.9 GB
# 1M tokens:   ~5.8 GB
# 4M tokens:   ~23.3 GB
```
---

## Chunked prefill flow

### Prefill phase

```
For each chunk:
├── 1. Prepare chunk input (block_size tokens)
├── 2. Get ring buffer slot: slot = chunk_idx % num_gpu_blocks
├── 3. Load previous KV chunks to ring slots[1..N-1]
├── 4. Model forward (all layers)
│     For each layer:
│     ├── Load previous KV from ring slots
│     ├── Compute attention (current chunk + previous)
│     ├── Write KV to prefill_buffer[layer_id]  ← per-layer!
│     └── Async offload to CPU (parallel across layers)
├── 5. Merge attention outputs (LSE)
└── 6. Record compute done for slot

Key: per-layer prefill buffer → Layer 0 offload || Layer 1 compute || Layer 2 load ...
```
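Step 5, "Merge attention outputs (LSE)", is the standard online-softmax merge: each chunk's partial attention output is kept together with its per-query log-sum-exp, and two partials over disjoint KV ranges combine exactly. A minimal sketch (not the branch's actual implementation):

```python
import torch

def merge_attn_outputs(o_a, lse_a, o_b, lse_b):
    """Merge partial attention outputs from two disjoint KV chunks.

    o_*:   [num_tokens, head_dim] (or [..., heads, head_dim]) partial outputs
    lse_*: [num_tokens] (or [..., heads]) log-sum-exp of the raw scores
    Returns (o, lse), identical to attending over both chunks at once.
    """
    lse = torch.logaddexp(lse_a, lse_b)          # combined normalizer
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)   # weight of chunk A's output
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)   # weight of chunk B's output
    return w_a * o_a + w_b * o_b, lse
```

The merge is associative, so chunks can be folded in as they are computed, which is what makes step 5 compatible with processing one block at a time.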
### Decode phase

```
├── 1. Set up pipeline: preload Layer 0 into buffer_a
├── 2. For each layer:
│     ├── Get KV from pipeline buffer (a or b)
│     ├── Trigger preload of the next layer into the other buffer
│     ├── Compute attention
│     └── Store to decode buffer
└── 3. End pipeline

Key: double-buffering → Layer N compute || Layer N+1 load
```
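The decode loop above can be sketched as follows; the function names and the `.wait()` handle returned by the async load are assumptions for illustration, not the branch's real API:

```python
def decode_all_layers(layers, load_layer_kv, compute_attn):
    """Cross-layer double-buffering sketch: while layer i computes,
    layer i+1's KV is already loading into the other buffer.

    load_layer_kv(layer_id, buf) starts an async copy and returns a
    handle whose .wait() blocks until the copy is done (assumed API).
    """
    bufs = [object(), object()]          # stand-ins for buffer_a / buffer_b
    pending = load_layer_kv(0, bufs[0])  # step 1: preload layer 0
    outputs = []
    for i in range(len(layers)):
        pending.wait()                   # KV for layer i is ready
        if i + 1 < len(layers):
            # overlap: kick off the next layer's load into the other buffer
            nxt = load_layer_kv(i + 1, bufs[(i + 1) % 2])
        outputs.append(compute_attn(layers[i], bufs[i % 2]))
        if i + 1 < len(layers):
            pending = nxt
    return outputs
```

The key property is visible in the event order: the load for layer N+1 is issued before layer N's attention is computed.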
---

## Merge strategy

### Base branch: tzj/vs_offload

**Rationale**:
1. More complete documentation system
2. More complete sparse attention implementations (QUEST, XAttention, etc.)
3. Cleaner code organization and comments
4. More actively developed and maintained

### Porting strategy

**Port from tzj/minference**:
1. GPU cache memory layout (no layer dimension, block-based)
2. Per-layer prefill buffer
3. Cross-layer pipeline buffers
4. Chunked prefill flow
5. Online LSE merging

**Keep from tzj/vs_offload**:
1. Documentation system
2. Sparse policy architecture
3. Code organization and comments

---

## Sparse policy strategy

**Strategy**: keep the architecture; implement only FULL for now

- **Keep** the sparse policy architecture and interfaces
- **Reserve** extension points for future policies such as QUEST
- **Implement only** the FULL policy at this stage, for correctness and stability
### Implementation

```python
class SparsePolicy(ABC):
    @property
    def supports_prefill(self) -> bool:
        return False

    @property
    def supports_decode(self) -> bool:
        return True

    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
        """Reserved for future policies (e.g., QUEST metadata collection)."""
        pass

    def select_blocks(self, available_blocks, context) -> List[int]:
        """FULL: return all available blocks."""
        return available_blocks


class FullAttentionPolicy(SparsePolicy):
    @property
    def supports_prefill(self) -> bool:
        return True

    @property
    def supports_decode(self) -> bool:
        return True
```
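With this interface in place, a future sparse policy only has to override `select_blocks`. For example, a hypothetical sink-plus-recent policy (self-contained, with the base class stubbed since the real one lives in `nanovllm/kvcache/sparse/policy.py`):

```python
from abc import ABC, abstractmethod
from typing import List

class SparsePolicy(ABC):
    """Minimal stand-in for the real SparsePolicy base class."""
    @abstractmethod
    def select_blocks(self, available_blocks: List[int], context) -> List[int]: ...

class SinkRecentPolicy(SparsePolicy):
    """Hypothetical future policy: keep the first (sink) block plus the
    most recent N blocks, and skip everything in between."""
    def __init__(self, num_recent: int = 2):
        self.num_recent = num_recent

    def select_blocks(self, available_blocks, context):
        if len(available_blocks) <= self.num_recent + 1:
            return available_blocks
        return [available_blocks[0]] + available_blocks[-self.num_recent:]

print(SinkRecentPolicy().select_blocks(list(range(6)), None))  # → [0, 4, 5]
```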
---

## Key APIs

### Ring buffer management

```python
# Prefill phase
get_write_slot_for_prefill(chunk_idx) -> slot_idx
get_load_slots_for_prefill(write_slot_idx) -> [slot_ids]

# Decode phase
get_load_slots_for_decode() -> [slot_ids]  # excludes decode_slot
```

### Per-layer operations

```python
# Loading
load_to_slot_layer(slot_idx, layer_id, cpu_block_id)
wait_slot_layer(slot_idx)

# Prefill buffer
get_prefill_buffer(layer_id) -> (k, v)
offload_prefill_buffer_async(layer_id, cpu_block_id, num_tokens)
wait_prefill_offload(layer_id)

# Pipeline
start_decode_pipeline(cpu_block_ids)
get_decode_layer_kv(layer_id, num_blocks) -> (k, v)
end_decode_pipeline()
```
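Sketched usage of the per-layer prefill APIs inside one layer's forward pass. The `engine` argument is assumed to expose the APIs listed above; `write_prefill_buffer` is a hypothetical helper standing in for copying new K/V into the buffer returned by `get_prefill_buffer`:

```python
def prefill_layer_forward(engine, layer_id, cpu_block_id, k_new, v_new):
    """One layer of chunked prefill: stage new KV, then offload asynchronously.

    Because the offload is per-layer and async, layer 0 can be copying to
    CPU while layer 1 computes (the parallelism noted in the flow diagram).
    """
    # The per-layer buffer is reused across chunks, so the previous chunk's
    # offload for this layer must have finished before we overwrite it.
    engine.wait_prefill_offload(layer_id)

    # Stage this chunk's KV in the per-layer buffer (hypothetical helper).
    engine.write_prefill_buffer(layer_id, k_new, v_new)

    # Kick off the CPU copy; other layers offload in parallel.
    engine.offload_prefill_buffer_async(layer_id, cpu_block_id, len(k_new))
```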
---

## Implementation phases

### Phase 1: memory layout refactor
- Change the GPU cache to a unified ring buffer
- Add the per-layer prefill buffer
- Add the cross-layer pipeline buffers

### Phase 2: API implementation
- Implement the ring buffer slot management APIs
- Implement the per-layer prefill offload APIs
- Implement the cross-layer pipeline APIs

### Phase 3: attention layer integration
- Modify the attention forward flow
- Integrate the per-layer prefill buffer
- Integrate the cross-layer pipeline

### Phase 4: model runner integration
- Implement the chunked prefill flow
- Integrate LSE merging
- Optimize the pipeline

### Phase 5: sparse policy integration (FULL)
- Design a unified policy interface
- Implement FullAttentionPolicy
- Reserve extension points for future policies such as QUEST

---

## Key decisions

1. **Block-based design first**: the core enabler of arbitrary-length inference
2. **Adopt the tzj/minference memory layout**: block-based GPU cache with no layer dimension
3. **Use tzj/vs_offload as the base branch**: better documentation and code organization
4. **Merge in stages**: lower complexity, easier to verify
5. **Sparse policy, FULL first**: keep the architecture, implement only FULL for now

---

## Expected results

### Memory usage (28-layer model, block_size=1024)

| Component | Memory |
|------|------|
| GPU unified ring buffer | ~4 MB |
| GPU per-layer prefill buffer | ~58 MB |
| GPU pipeline buffers (×2) | ~1 GB |
| GPU decode buffer | ~58 MB |
| **GPU total** | **~1.6 GB (fixed)** |
| CPU cache (4M tokens) | ~23.3 GB |
| **Total (4M tokens)** | **~24.9 GB** ✅ only ~1.6 GB of it on the 24GB RTX 3090 |

### Capabilities

- ✅ Inference at 4M, 8M, 16M+ tokens
- ✅ Fixed GPU memory, independent of sequence length
- ✅ Fully parallel layerwise offload
- ✅ Cross-layer pipelining

---

## References

- **OffloadEngine**: `nanovllm/kvcache/offload_engine.py`
- **Attention Layer**: `nanovllm/layers/attention.py`
- **Model Runner**: `nanovllm/engine/model_runner.py`
- **Sparse Policy**: `nanovllm/kvcache/sparse/policy.py`


@@ -62,7 +62,6 @@ class Config:
     xattn_keep_sink: bool = False    # Always keep first block (sink tokens)
     xattn_keep_recent: bool = False  # Always keep recent diagonal blocks
     xattn_norm: float = 1.0          # Normalization factor for attention scores
-    xattn_use_bsa: bool = True       # Use Block Sparse Attention library (requires installation)
     def __post_init__(self):
         assert os.path.isdir(self.model)


@@ -57,8 +57,8 @@ class ModelRunner:
         load_model(self.model, config.model)
         self.sampler = GreedySampler()
-        # Initialize attention_policy before warmup (will be configured in allocate_kv_cache)
-        self.attention_policy = None
+        # Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache)
+        self.sparse_prefill_policy = None
         #> Disable warmup for debugging
         self.warmup_model()
@@ -178,9 +178,11 @@ class ModelRunner:
         # Create KV cache manager using factory
         self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)
-        # Create attention policy (always, including FULL)
-        # In layerwise offload mode, all attention goes through the policy
-        from nanovllm.kvcache.sparse import create_attention_policy
+        # Create sparse prefill policy
+        # This is used for both GPU-only and CPU offload modes when policy supports prefill
+        self.sparse_prefill_policy = None
+        if config.sparse_policy != SparsePolicyType.FULL:
+            from nanovllm.kvcache.sparse import create_sparse_policy
             # Get policy-specific parameters based on type
             if config.sparse_policy == SparsePolicyType.XATTN:
@@ -192,9 +194,8 @@ class ModelRunner:
                     "keep_sink": config.xattn_keep_sink,
                     "keep_recent": config.xattn_keep_recent,
                     "norm": config.xattn_norm,
-                    "use_bsa": config.xattn_use_bsa,
                 }
-            elif config.sparse_policy == SparsePolicyType.MINFERENCE:
+            else:  # MINFERENCE or others
                 policy_kwargs = {
                     "vertical_size": config.minference_vertical_size,
                     "slash_size": config.minference_slash_size,
@@ -202,11 +203,13 @@ class ModelRunner:
                     "num_sink_tokens": config.minference_num_sink_tokens,
                     "num_recent_diags": config.minference_num_recent_diags,
                 }
-        else:  # FULL or QUEST
-            policy_kwargs = {}
-        self.attention_policy = create_attention_policy(config.sparse_policy, **policy_kwargs)
-        logger.info(f"Attention policy: {self.attention_policy}")
+            policy = create_sparse_policy(config.sparse_policy, **policy_kwargs)
+            # Only use if policy supports sparse prefill
+            if policy.supports_prefill:
+                self.sparse_prefill_policy = policy
+                logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}")
         # Allocate cache through manager
         self.kvcache_manager.allocate_cache(
@@ -392,7 +395,7 @@ class ModelRunner:
         set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                     slot_mapping, None, block_tables,
-                    attention_policy=self.attention_policy)
+                    sparse_prefill_policy=self.sparse_prefill_policy)
         return input_ids, positions
     def prepare_decode(self, seqs: list[Sequence]):
@@ -589,10 +592,20 @@ class ModelRunner:
         # RoPE
         q, k = layer.self_attn.rotary_emb(positions, q, k)
-        # Compute attention using policy (uses k, v directly - before store!)
-        attn_output = self.attention_policy.compute_prefill(
-            q, k, v, layer_id,
-            softmax_scale=layer.self_attn.attn.scale,
-        )
+        # Sparse or Full attention (uses k, v directly - before store!)
+        if self.sparse_prefill_policy is not None:
+            attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
+                q, k, v, layer_id
+            )
+        else:
+            attn_output = flash_attn_varlen_func(
+                q, k, v,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=total_tokens,
+                max_seqlen_k=total_tokens,
+                softmax_scale=layer.self_attn.attn.scale,
+                causal=True,
+            )
         # O projection
@@ -859,10 +872,22 @@ class ModelRunner:
         # RoPE
         q, k = layer.self_attn.rotary_emb(positions, q, k)
-        # Compute attention using policy
-        attn_output = self.attention_policy.compute_prefill(
-            q, k, v, layer_id,
-            softmax_scale=layer.self_attn.attn.scale,
-        )
+        # Sparse or Full attention
+        if self.sparse_prefill_policy is not None:
+            # MInference or other sparse prefill policy
+            attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
+                q, k, v, layer_id
+            )
+        else:
+            # Full attention using FlashAttention
+            attn_output = flash_attn_varlen_func(
+                q, k, v,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=total_tokens,
+                max_seqlen_k=total_tokens,
+                softmax_scale=layer.self_attn.attn.scale,
+                causal=True,
+            )
         # O projection


@@ -1,56 +1,49 @@
 """
-Attention Policy module for layerwise offload mode.
+Sparse Attention Policy module.
-Provides pluggable policies for attention computation:
-- FullAttentionPolicy: Standard FlashAttention (no sparsity)
-- XAttentionPolicy: Sparse prefill using XAttention algorithm
-- MInferencePolicy: MInference sparse attention
-- QuestPolicy: Quest block selection (for chunked offload)
+Provides pluggable policies for selecting which KV blocks to load
+during chunked attention with CPU offload.
 Usage:
-    from nanovllm.kvcache.sparse import create_attention_policy, SparsePolicyType
+    from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType
     # Create policy using factory function
-    policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
-    # Use policy for attention
-    attn_output = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
+    policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)
     # Or create custom policy
-    class MyPolicy(AttentionPolicy):
+    class MyPolicy(SparsePolicy):
         supports_prefill = True
         supports_decode = True
-        def compute_prefill(self, q, k, v, layer_id, softmax_scale):
-            # Custom attention computation
-            ...
+        def select_blocks(self, available_blocks, ctx):
+            return available_blocks[:5]  # Just first 5 blocks
 """
 from nanovllm.config import SparsePolicyType
-from nanovllm.kvcache.sparse.policy import AttentionPolicy, SparsePolicy, PolicyContext
+from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
 from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
 from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
 from nanovllm.kvcache.sparse.minference import MInferencePolicy
 from nanovllm.kvcache.sparse.xattn import XAttentionPolicy
-def create_attention_policy(policy_type: SparsePolicyType, **kwargs) -> AttentionPolicy:
+def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
     """
-    Create an attention policy instance from an enum type.
+    Create a sparse policy instance from an enum type.
-    All attention (including full attention) goes through a policy in layerwise
-    offload mode. The policy is responsible for computing prefill/decode attention.
+    The returned policy is not yet initialized. Call policy.initialize()
+    or let the framework call it during KV cache allocation.
     Args:
-        policy_type: SparsePolicyType enum value (FULL, XATTN, MINFERENCE, QUEST)
+        policy_type: SparsePolicyType enum value
         **kwargs: Policy-specific configuration options
     Returns:
-        AttentionPolicy instance
+        SparsePolicy instance (not initialized)
     Example:
-        policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
-        attn_out = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
+        policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
+        policy.initialize(num_layers=28, num_kv_heads=8, ...)
     """
     if policy_type == SparsePolicyType.FULL:
         return FullAttentionPolicy()
@@ -82,32 +75,21 @@ def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
             keep_sink=kwargs.get("keep_sink", False),
             keep_recent=kwargs.get("keep_recent", False),
             norm=kwargs.get("norm", 1.0),
-            use_bsa=kwargs.get("use_bsa", True),
         )
     else:
         raise ValueError(f"Unknown policy type: {policy_type}")
-# Backward compatibility alias
-create_sparse_policy = create_attention_policy
 __all__ = [
-    # New interface
-    "AttentionPolicy",
-    "create_attention_policy",
-    # Backward compatibility
     "SparsePolicy",
-    "create_sparse_policy",
-    # Common types
     "PolicyContext",
     "SparsePolicyType",
-    # Policy implementations
     "FullAttentionPolicy",
     "QuestPolicy",
     "QuestConfig",
    "BlockMetadataManager",
     "MInferencePolicy",
     "XAttentionPolicy",
+    "create_sparse_policy",
 ]


@@ -1,21 +1,20 @@
 """
-Full attention policy - standard FlashAttention without sparsity.
+Full attention policy - loads all blocks (no sparsity).
 This serves as a baseline and default policy when sparse
 attention is not needed.
 """
-from typing import Optional
-import torch
-from .policy import AttentionPolicy
+from typing import List
+from .policy import SparsePolicy, PolicyContext
-class FullAttentionPolicy(AttentionPolicy):
+class FullAttentionPolicy(SparsePolicy):
     """
-    Full attention policy using FlashAttention (no sparsity).
-    This is the default behavior with standard causal attention.
-    All tokens attend to all previous tokens.
+    Full attention policy that loads all available blocks.
+    This is the default behavior with no sparsity - all previous
+    KV cache blocks are loaded for each query chunk.
     Use this as:
     - A baseline for comparing sparse policies
@@ -26,55 +25,15 @@ class FullAttentionPolicy(SparsePolicy):
     # Full attention supports both prefill and decode
     supports_prefill = True
     supports_decode = True
+    requires_block_selection = False  # Load all blocks, no selective loading
-    def estimate(
+    def select_blocks(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        layer_id: int,
-    ) -> Optional[torch.Tensor]:
-        """
-        Full attention - no sparse mask needed.
-        Returns None to indicate full attention should be used.
-        """
-        return None
+        available_blocks: List[int],
+        ctx: PolicyContext,
+    ) -> List[int]:
+        """Return all blocks - no sparsity."""
+        return available_blocks
-    def compute_prefill(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer_id: int,
-        softmax_scale: float,
-    ) -> torch.Tensor:
-        """
-        Compute full causal attention using FlashAttention.
-        Args:
-            q: Query tensor [seq_len, num_heads, head_dim]
-            k: Key tensor [seq_len, num_kv_heads, head_dim]
-            v: Value tensor [seq_len, num_kv_heads, head_dim]
-            layer_id: Transformer layer index
-            softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
-        Returns:
-            Attention output [seq_len, num_heads, head_dim]
-        """
-        from flash_attn.flash_attn_interface import flash_attn_varlen_func
-        seq_len = q.shape[0]
-        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
-        return flash_attn_varlen_func(
-            q, k, v,
-            cu_seqlens_q=cu_seqlens,
-            cu_seqlens_k=cu_seqlens,
-            max_seqlen_q=seq_len,
-            max_seqlen_k=seq_len,
-            softmax_scale=softmax_scale,
-            causal=True,
-        )
     def __repr__(self) -> str:
         return "FullAttentionPolicy()"


@@ -10,10 +10,10 @@ from typing import List, Tuple, Optional
 import torch
 import torch.nn.functional as F
-from nanovllm.kvcache.sparse.policy import AttentionPolicy, PolicyContext
+from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
-class MInferencePolicy(AttentionPolicy):
+class MInferencePolicy(SparsePolicy):
     """
     MInference sparse prefill policy using vertical + slash pattern.
@@ -347,33 +347,6 @@ class MInferencePolicy(SparsePolicy):
         return o
-    def compute_prefill(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer_id: int,
-        softmax_scale: float,
-    ) -> torch.Tensor:
-        """
-        Compute MInference sparse prefill attention.
-        This is the new unified interface for attention policies.
-        Delegates to sparse_prefill_attention (ignores softmax_scale as MInference
-        computes it internally from head_dim).
-        Args:
-            q: Query tensor [seq_len, num_heads, head_dim]
-            k: Key tensor [seq_len, num_kv_heads, head_dim]
-            v: Value tensor [seq_len, num_kv_heads, head_dim]
-            layer_id: Transformer layer index
-            softmax_scale: Softmax scaling factor (unused, computed internally)
-        Returns:
-            Attention output [seq_len, num_heads, head_dim]
-        """
-        return self.sparse_prefill_attention(q, k, v, layer_id)
     def __repr__(self) -> str:
         return (f"MInferencePolicy("
                 f"adaptive_budget={self.adaptive_budget}, "


@@ -1,18 +1,13 @@
 """
-Base class for attention policies in layerwise offload mode.
-AttentionPolicy defines the interface for all attention computation,
-including full attention and sparse attention methods like XAttention.
-Key methods:
-- estimate(): Compute sparse attention mask (optional, returns None for full attention)
-- compute_prefill(): Compute prefill attention
-- compute_decode(): Compute decode attention (default implementation provided)
+Base class for sparse attention policies.
+Sparse attention policies determine which KV cache blocks to load
+from CPU for each query chunk during chunked attention computation.
 """
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import List, Optional, Any
 import torch
 # Import SparsePolicyType from config to avoid circular imports
@@ -22,10 +17,10 @@ from nanovllm.config import SparsePolicyType
 @dataclass
 class PolicyContext:
     """
-    Context passed to attention policy for block selection.
-    This dataclass contains all information needed by an attention policy
-    for sparse estimation and attention computation.
+    Context passed to sparse policy for block selection.
+    This dataclass contains all information needed by a sparse policy
+    to decide which blocks to load for the current query chunk.
     """
     query_chunk_idx: int
@@ -54,41 +49,40 @@ class PolicyContext:
     """Total KV sequence length so far (for reference)."""
-class AttentionPolicy(ABC):
+class SparsePolicy(ABC):
     """
-    Base class for attention policies in layerwise offload mode.
-    All attention computation goes through a policy, including both
-    full attention and sparse attention methods.
-    The policy interface is designed for layerwise offload where:
-    - The entire KV cache for a layer is on GPU during computation
-    - No need for block loading from CPU during attention
-    - estimate() returns a sparse mask (or None for full attention)
-    - compute_prefill()/compute_decode() perform the actual attention
+    Abstract base class for sparse attention policies.
+    Subclass this and implement select_blocks() to create custom
+    sparse attention patterns. The policy receives context about
+    the current query chunk and returns which KV blocks to load.
     Attributes:
         supports_prefill: Whether this policy can be used for prefill phase.
         supports_decode: Whether this policy can be used for decode phase.
     Example:
-        class MyPolicy(AttentionPolicy):
-            supports_prefill = True
+        class MySparsePolicy(SparsePolicy):
+            supports_prefill = False  # decode-only policy
             supports_decode = True
-            def estimate(self, q, k, layer_id):
-                # Return sparse mask or None
-                return None
-            def compute_prefill(self, q, k, v, layer_id, softmax_scale):
-                # Compute attention
-                return flash_attn_varlen_func(q, k, v, ...)
+            def select_blocks(self, available_blocks, ctx):
+                # Load first block and last 2 blocks
+                if len(available_blocks) <= 3:
+                    return available_blocks
+                return [available_blocks[0]] + available_blocks[-2:]
     """
     # Compatibility flags - override in subclasses
     supports_prefill: bool = True
     supports_decode: bool = True
+    # Whether this policy requires selective block loading during decode
+    # If True: OffloadEngine will call select_blocks() before loading KV from CPU
+    # If False: OffloadEngine will load all blocks (select_blocks ignored for load)
+    # Example: MInference=False (only affects attention), Quest=True (affects load)
+    requires_block_selection: bool = False
     def initialize(
         self,
         num_layers: int,
@@ -102,7 +96,7 @@ class SparsePolicy(ABC):
        Initialize policy resources.
         Called by the framework after KV cache is allocated. Override this
-        to create metadata structures or pre-allocate buffers.
+        to create metadata structures (e.g., BlockMetadataManager for Quest).
         Default implementation does nothing.
         Args:
@@ -115,98 +109,76 @@ class SparsePolicy(ABC):
         """
         pass
-    def estimate(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        layer_id: int,
-    ) -> Optional[torch.Tensor]:
-        """
-        Estimate sparse attention mask.
-        For sparse policies (e.g., XAttention), computes block-level importance
-        and returns a boolean mask indicating which blocks to attend.
-        For full attention policy, returns None.
-        This corresponds to xattn_estimate() in COMPASS.
-        Args:
-            q: Query tensor [seq_len, num_heads, head_dim]
-            k: Key tensor [seq_len, num_kv_heads, head_dim]
-            layer_id: Transformer layer index
-        Returns:
-            sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
-                or None for full attention
-        """
-        return None
     @abstractmethod
-    def compute_prefill(
+    def select_blocks(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer_id: int,
-        softmax_scale: float,
+        available_blocks: List[int],
+        ctx: PolicyContext,
+    ) -> List[int]:
) -> torch.Tensor:
""" """
Compute prefill attention. Select which KV blocks to load for the current query chunk.
The entire KV cache for this layer is on GPU. Compute attention This is the core method that defines the sparse attention pattern.
between Q and K/V, optionally using sparse mask from estimate(). The returned blocks will be loaded from CPU to GPU for attention
computation against the current query chunk.
Args: Args:
q: Query tensor [seq_len, num_heads, head_dim] available_blocks: List of CPU block IDs that contain KV cache
k: Key tensor [seq_len, num_kv_heads, head_dim] from previous chunks. These are ordered by
v: Value tensor [seq_len, num_kv_heads, head_dim] their position in the sequence.
layer_id: Transformer layer index ctx: PolicyContext with information about the current query
softmax_scale: Softmax scaling factor (1/sqrt(head_dim)) chunk, layer, phase (prefill/decode), etc.
Returns: Returns:
Attention output [seq_len, num_heads, head_dim] List of block IDs to load (must be a subset of available_blocks).
The order may affect performance (sequential access is faster).
Returning [] means no previous blocks will be loaded.
""" """
pass pass
def compute_decode( def on_prefill_offload(
self, self,
q: torch.Tensor, cpu_block_id: int,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int, layer_id: int,
softmax_scale: float, k_cache: torch.Tensor,
) -> torch.Tensor: num_valid_tokens: int,
) -> None:
""" """
Compute decode attention. Hook called when a block is offloaded during prefill phase.
KV is provided from ring buffer, containing prefill tokens + decoded tokens. Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
Default implementation uses FlashAttention. Override this to collect metadata about blocks (e.g., min/max keys
for Quest-style selection). Default implementation does nothing.
Args: Args:
q: Query tensor [1, num_heads, head_dim] cpu_block_id: The CPU block ID that will be written
k: Key tensor [context_len+1, num_kv_heads, head_dim]
v: Value tensor [context_len+1, num_kv_heads, head_dim]
layer_id: Transformer layer index layer_id: Transformer layer index
softmax_scale: Softmax scaling factor k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
num_valid_tokens: Number of valid tokens in this block
Returns:
Attention output [1, num_heads, head_dim]
""" """
from flash_attn.flash_attn_interface import flash_attn_varlen_func pass
context_len = k.shape[0] def on_decode_offload(
cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device) self,
cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device) cpu_block_id: int,
layer_id: int,
k_cache: torch.Tensor,
num_valid_tokens: int,
) -> None:
"""
Hook called when a block is offloaded during decode phase.
return flash_attn_varlen_func( Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
q, k, v, Override this to update metadata about blocks. Default implementation
cu_seqlens_q=cu_seqlens_q, does nothing.
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=1, Args:
max_seqlen_k=context_len, cpu_block_id: The CPU block ID that will be written
softmax_scale=softmax_scale, layer_id: Transformer layer index
causal=False, k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
) num_valid_tokens: Number of valid tokens in this block
"""
pass
def reset(self) -> None: def reset(self) -> None:
""" """
@@ -217,9 +189,32 @@ class AttentionPolicy(ABC):
""" """
pass pass
def sparse_prefill_attention(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
) -> torch.Tensor:
"""
Compute sparse attention for prefill phase.
This method is called when supports_prefill=True and the policy
is used for GPU-only sparse prefill (no CPU offload).
Args:
q: Query tensor [seq_len, num_heads, head_dim]
k: Key tensor [seq_len, num_kv_heads, head_dim]
v: Value tensor [seq_len, num_kv_heads, head_dim]
layer_id: Current transformer layer index
Returns:
Attention output [seq_len, num_heads, head_dim]
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement sparse_prefill_attention. "
"Set supports_prefill=False or implement this method."
)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"{self.__class__.__name__}()" return f"{self.__class__.__name__}()"
# Backward compatibility alias
SparsePolicy = AttentionPolicy
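The select_blocks() contract above can be exercised without the rest of the framework. A standalone sketch follows: PolicyContext is reduced to a few illustrative fields, and StreamingWindowPolicy is a hypothetical subclass (not part of nano-vllm) that keeps the first "sink" block plus a recent window, in the spirit of the docstring's example.

```python
from dataclasses import dataclass
from typing import List

# Stand-in for nanovllm's PolicyContext, reduced to the fields used here.
@dataclass
class PolicyContext:
    query_chunk_idx: int
    layer_id: int
    is_prefill: bool

class StreamingWindowPolicy:
    """Hypothetical decode-only policy: first (sink) block + N most recent blocks."""
    supports_prefill = False
    supports_decode = True
    requires_block_selection = True  # OffloadEngine should load only selected blocks

    def __init__(self, num_recent_blocks: int = 4):
        self.num_recent_blocks = num_recent_blocks

    def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
        # Returned list is a subset of available_blocks, order preserved
        # (sequential access is faster).
        if len(available_blocks) <= self.num_recent_blocks + 1:
            return available_blocks
        return [available_blocks[0]] + available_blocks[-self.num_recent_blocks:]

policy = StreamingWindowPolicy(num_recent_blocks=2)
ctx = PolicyContext(query_chunk_idx=0, layer_id=0, is_prefill=False)
print(policy.select_blocks(list(range(8)), ctx))  # [0, 6, 7]
```

With `requires_block_selection = True`, only these blocks would be copied CPU→GPU before attention; with `False`, the engine would load all blocks and the selection would only shape the attention computation.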


@@ -11,7 +11,7 @@ import logging
 import torch
 from dataclasses import dataclass
 from typing import List, Tuple, Optional

-from .policy import AttentionPolicy, PolicyContext
+from .policy import SparsePolicy, PolicyContext

 logger = logging.getLogger(__name__)
@@ -137,7 +137,7 @@ class QuestConfig:
     """Always include this many recent blocks (last N blocks), in addition to Top-K."""

-class QuestPolicy(AttentionPolicy):
+class QuestPolicy(SparsePolicy):
     """
     Quest-style Top-K block selection using min/max key bounds.
@@ -317,25 +317,6 @@ class QuestPolicy(AttentionPolicy):
         if self.metadata is not None:
             self.metadata.reset()

-    def compute_prefill(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer_id: int,
-        softmax_scale: float,
-    ) -> torch.Tensor:
-        """
-        Quest does not support prefill - raises error.
-
-        Quest is a decode-only policy for selective block loading.
-        For prefill, use FullAttentionPolicy or XAttentionPolicy.
-        """
-        raise NotImplementedError(
-            "QuestPolicy does not support prefill. "
-            "Use FullAttentionPolicy or XAttentionPolicy for prefill."
-        )
-
     def __repr__(self) -> str:
         return (
             f"QuestPolicy(topk={self.config.topk_blocks}, "
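The "min/max key bounds" scoring that QuestPolicy relies on can be sketched in pure Python. Per-block element-wise min/max of keys (the metadata an on_prefill_offload hook would collect) gives an upper bound on q·k for any key in the block: score = Σ_d max(q_d·min_d, q_d·max_d). Function names here are illustrative, not the actual QuestPolicy API.

```python
def block_bounds(keys):
    """keys: list of head_dim-length key vectors in one block -> (mins, maxs)."""
    dims = range(len(keys[0]))
    mins = [min(k[d] for k in keys) for d in dims]
    maxs = [max(k[d] for k in keys) for d in dims]
    return mins, maxs

def quest_score(q, mins, maxs):
    # Upper bound on q . k over all keys k in the block.
    return sum(max(qd * lo, qd * hi) for qd, lo, hi in zip(q, mins, maxs))

def topk_blocks(q, bounds, k):
    # Rank blocks by their score bound, return the top-k IDs in sequence order.
    scores = [(quest_score(q, lo, hi), i) for i, (lo, hi) in enumerate(bounds)]
    return sorted(i for _, i in sorted(scores, reverse=True)[:k])

blocks = [
    [[1.0, 0.0], [0.5, 0.2]],    # block 0
    [[-1.0, 2.0], [-0.5, 1.5]],  # block 1
    [[0.1, 0.1], [0.0, 0.2]],    # block 2
]
bounds = [block_bounds(b) for b in blocks]
print(topk_blocks([0.0, 1.0], bounds, k=1))  # [1]: largest q·k upper bound
```

Because the bound never underestimates the true maximum score, Top-K selection on these bounds never prunes a block that could contain the highest-scoring key.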


@@ -4,56 +4,48 @@ XAttention sparse attention policy for nano-vllm.
 Implements the XAttention algorithm from COMPASS, using chunked estimation
 and block sparse attention for efficient long-context inference.

-Architecture:
-    XAttention = Estimate (Triton) + Compute (BSA)
-    - Estimate: xattn_estimate() computes block-level importance scores
-    - Compute: block_sparse_attn_func() executes sparse attention
-
 Reference: COMPASS/compass/src/Xattention.py
 """

 import math
-from typing import Optional
+from typing import List, Optional

 import torch
 import torch.nn.functional as F

-from nanovllm.kvcache.sparse.policy import AttentionPolicy
-
-# BSA block size is fixed at 128 (hardcoded in block_sparse_attn)
-BSA_BLOCK_SIZE = 128
+from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
+from nanovllm.kvcache.sparse.kernels import (
+    flat_group_gemm_fuse_reshape,
+    softmax_fuse_block_sum,
+)
+from nanovllm.kvcache.sparse.utils import find_blocks_chunked

-class XAttentionPolicy(AttentionPolicy):
+class XAttentionPolicy(SparsePolicy):
     """
     XAttention sparse prefill policy using chunked estimation + block sparse attention.

     This policy estimates sparse attention patterns by:
-    1. Chunked QK computation using Triton kernels (via nanovllm.ops.xattn)
+    1. Chunked QK computation using Triton kernels
     2. Block-wise softmax with importance scores
     3. Block selection based on threshold
-    4. Block sparse attention computation using MIT-HAN-LAB BSA library
-
-    The key method is estimate() which calls xattn_estimate() from nanovllm.ops
-    to compute the sparse attention mask.
+    4. Block sparse attention computation

     Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
-    BSA library: https://github.com/mit-han-lab/Block-Sparse-Attention
     """

     supports_prefill = True
-    supports_decode = True  # Uses default FlashAttention for decode
+    supports_decode = False  # XAttention is prefill-only
+    requires_block_selection = False  # Only affects attention computation

     def __init__(
         self,
         stride: int = 8,
         threshold: float = 0.9,
-        block_size: int = 128,
-        chunk_size: int = 16384,
+        chunk_size: Optional[int] = None,
         use_triton: bool = True,
         keep_sink: bool = False,
         keep_recent: bool = False,
         norm: float = 1.0,
-        use_bsa: bool = True,
     ):
         """
         Initialize XAttention policy.
@@ -61,28 +53,19 @@ class XAttentionPolicy(AttentionPolicy):
         Args:
             stride: Stride for reorganizing Q/K (default: 8)
             threshold: Block selection threshold, 0-1 (default: 0.9)
-            block_size: Block size for sparse attention (default: 128, must match BSA)
-            chunk_size: Chunk size for estimation (default: 16384)
+            chunk_size: Chunk size for estimation (auto if None)
             use_triton: Use Triton kernels (requires SM 80+)
             keep_sink: Always keep first block (sink tokens)
             keep_recent: Always keep recent diagonal blocks
             norm: Normalization factor for attention scores
-            use_bsa: Use Block Sparse Attention library (default: True)
         """
         self.stride = stride
         self.threshold = threshold
-        self.block_size = block_size
         self.chunk_size = chunk_size
         self.use_triton = use_triton
         self.keep_sink = keep_sink
         self.keep_recent = keep_recent
         self.norm = norm
-        self.use_bsa = use_bsa
-
-        # BSA requires block_size = 128
-        if self.use_bsa and self.block_size != BSA_BLOCK_SIZE:
-            print(f"XAttention: BSA requires block_size=128, adjusting from {self.block_size}")
-            self.block_size = BSA_BLOCK_SIZE

         # Check Triton availability
         if self.use_triton:
@@ -96,207 +79,380 @@ class XAttentionPolicy(AttentionPolicy):
             self.use_triton = False
             print("XAttention: Triton not available. Falling back to PyTorch.")

-        # Check BSA availability
-        if self.use_bsa:
-            try:
-                from block_sparse_attn import block_sparse_attn_func
-            except ImportError:
-                self.use_bsa = False
-                print("XAttention: block_sparse_attn not available. Falling back to FlashAttention.")
-
-    def estimate(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        layer_id: int,
-    ) -> Optional[torch.Tensor]:
-        """
-        Estimate sparse attention mask using XAttention algorithm.
-
-        Calls xattn_estimate() from nanovllm.ops.xattn to compute block-level
-        importance scores and generate a sparse boolean mask.
-
-        Args:
-            q: Query tensor [seq_len, num_heads, head_dim]
-            k: Key tensor [seq_len, num_kv_heads, head_dim]
-            layer_id: Transformer layer index
-
-        Returns:
-            sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
-                or None if estimation fails (fallback to full attention)
-        """
-        try:
-            from nanovllm.ops.xattn import xattn_estimate
-
-            seq_len, num_heads, head_dim = q.shape
-            num_kv_heads = k.shape[1]
-
-            # Convert to [batch, heads, seq, dim] format expected by xattn_estimate
-            q_bhsd = q.unsqueeze(0).transpose(1, 2)  # [1, num_heads, seq_len, head_dim]
-            k_bhsd = k.unsqueeze(0).transpose(1, 2)  # [1, num_kv_heads, seq_len, head_dim]
-
-            # Handle GQA: expand k to match q heads for estimation
-            if num_kv_heads != num_heads:
-                # GQA: expand k by repeating
-                repeat_factor = num_heads // num_kv_heads
-                k_bhsd = k_bhsd.repeat(1, repeat_factor, 1, 1)
-
-            # Call xattn_estimate
-            attn_sums, sparse_mask = xattn_estimate(
-                q_bhsd, k_bhsd,
-                block_size=self.block_size,
-                stride=self.stride,
-                norm=self.norm,
-                threshold=self.threshold,
-                chunk_size=self.chunk_size,
-                use_triton=self.use_triton,
-                causal=True,
-                keep_sink=self.keep_sink,
-                keep_recent=self.keep_recent,
-            )
-            return sparse_mask
-        except Exception as e:
-            # If estimation fails, return None to use full attention
-            print(f"XAttention estimate failed: {e}, falling back to full attention")
-            return None
-
-    def compute_prefill(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer_id: int,
-        softmax_scale: float,
-    ) -> torch.Tensor:
-        """
-        Compute XAttention sparse prefill attention.
-
-        Flow:
-        1. Call estimate() to get sparse mask
-        2. If mask is None or BSA unavailable, use full FlashAttention
-        3. Otherwise, use block_sparse_attn_func with mask
-
-        Args:
-            q: Query tensor [seq_len, num_heads, head_dim]
-            k: Key tensor [seq_len, num_kv_heads, head_dim]
-            v: Value tensor [seq_len, num_kv_heads, head_dim]
-            layer_id: Transformer layer index
-            softmax_scale: Softmax scaling factor
-
-        Returns:
-            Attention output [seq_len, num_heads, head_dim]
-        """
-        # If BSA is disabled, use full attention directly (skip estimation)
-        if not self.use_bsa:
-            return self._full_attention(q, k, v, softmax_scale)
-
-        # Step 1: Estimate sparse mask
-        sparse_mask = self.estimate(q, k, layer_id)
-
-        # Step 2: Compute attention
-        if sparse_mask is None:
-            # Estimation failed, fallback to full FlashAttention
-            return self._full_attention(q, k, v, softmax_scale)
-
-        # Use block sparse attention with mask
-        return self._block_sparse_attention(q, k, v, sparse_mask, softmax_scale)
-
-    def _block_sparse_attention(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        sparse_mask: torch.Tensor,
-        softmax_scale: float,
-    ) -> torch.Tensor:
-        """
-        Compute block sparse attention using MIT-HAN-LAB BSA library.
-
-        Args:
-            q: Query tensor [seq_len, num_heads, head_dim]
-            k: Key tensor [seq_len, num_kv_heads, head_dim]
-            v: Value tensor [seq_len, num_kv_heads, head_dim]
-            sparse_mask: Block mask [batch, num_heads, q_blocks, k_blocks]
-            softmax_scale: Softmax scaling factor
-
-        Returns:
-            Attention output [seq_len, num_heads, head_dim]
-        """
-        from block_sparse_attn import block_sparse_attn_func
-
-        seq_len, num_heads, head_dim = q.shape
-        num_kv_heads = k.shape[1]
-
-        # Handle GQA: expand K/V to match Q heads
-        if num_kv_heads != num_heads:
-            repeat_factor = num_heads // num_kv_heads
-            k = k.repeat_interleave(repeat_factor, dim=1)
-            v = v.repeat_interleave(repeat_factor, dim=1)
-
-        # Cumulative sequence lengths (batch=1)
-        cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
-        cu_seqlens_k = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
-
-        # Head mask type: 1 for all heads using block sparse
-        head_mask_type = torch.ones(num_heads, dtype=torch.int32, device=q.device)
-
-        # Trim sparse_mask to actual block counts
-        q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
-        k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
-        block_mask = sparse_mask[:, :, :q_blocks, :k_blocks].contiguous()
-
-        # Call BSA
-        attn_output = block_sparse_attn_func(
-            q, k, v,
-            cu_seqlens_q, cu_seqlens_k,
-            head_mask_type,
-            None,  # streaming_info (left_mask)
-            block_mask,
-            seq_len,
-            seq_len,
-            p_dropout=0.0,
-            deterministic=True,
-            softmax_scale=softmax_scale,
-            is_causal=True,
-        )
-        return attn_output
-
-    def _full_attention(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        softmax_scale: float,
-    ) -> torch.Tensor:
-        """
-        Compute full causal attention using FlashAttention.
-
-        Args:
-            q: Query tensor [seq_len, num_heads, head_dim]
-            k: Key tensor [seq_len, num_kv_heads, head_dim]
-            v: Value tensor [seq_len, num_kv_heads, head_dim]
-            softmax_scale: Softmax scaling factor
-
-        Returns:
-            Attention output [seq_len, num_heads, head_dim]
-        """
-        from flash_attn.flash_attn_interface import flash_attn_varlen_func
-
-        seq_len = q.shape[0]
-        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
-
-        return flash_attn_varlen_func(
-            q, k, v,
-            cu_seqlens_q=cu_seqlens,
-            cu_seqlens_k=cu_seqlens,
-            max_seqlen_q=seq_len,
-            max_seqlen_k=seq_len,
-            softmax_scale=softmax_scale,
-            causal=True,
-        )
+    def select_blocks(
+        self,
+        available_blocks: List[int],
+        ctx: PolicyContext,
+    ) -> List[int]:
+        """
+        Select blocks for decode phase.
+
+        XAttention is prefill-only, so this method is only used as a fallback.
+        Returns all available blocks by default.
+        """
+        # XAttention is prefill-only, but we need to implement this abstract method
+        # Since requires_block_selection=False, this won't be called for loading
+        return available_blocks
+
+    def sparse_prefill_attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer_id: int,
+    ) -> torch.Tensor:
+        """
+        Compute XAttention sparse attention for prefill.
+
+        Args:
+            q: Query tensor [seq_len, num_heads, head_dim]
+            k: Key tensor [seq_len, num_kv_heads, head_dim]
+            v: Value tensor [seq_len, num_kv_heads, head_dim]
+            layer_id: Current transformer layer index
+
+        Returns:
+            Attention output [seq_len, num_heads, head_dim]
+        """
+        seq_len = q.shape[0]
+        num_heads = q.shape[1]
+        head_dim = q.shape[2]
+        num_kv_heads = k.shape[1]
+
+        # Use FlashAttention directly for CPU offload mode
+        # FlashAttention supports GQA natively
+        try:
+            from flash_attn.flash_attn_interface import flash_attn_varlen_func
+
+            cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
+
+            attn_output = flash_attn_varlen_func(
+                q, k, v,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=seq_len,
+                max_seqlen_k=seq_len,
+                softmax_scale=1.0 / math.sqrt(head_dim),
+                causal=True,
+            )
+            return attn_output
+        except Exception as e:
+            # Fallback: PyTorch SDPA (supports GQA natively)
+            print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
+            attn_output = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=None,
+                is_causal=True,
+                scale=1.0 / math.sqrt(head_dim)
+            )
+            return attn_output
def _xattn_offload_prefill(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
causal: bool = True,
) -> torch.Tensor:
"""
Simplified XAttention prefill for CPU offload mode.
Uses FlashAttention with full context since chunked estimation
with full key_states requires special handling.
"""
batch_size, num_heads, q_len, head_dim = query_states.shape
_, _, k_len, _ = key_states.shape
# Use FlashAttention with full context
# In offload mode, keys are already on CPU and loaded as needed
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
# Convert to [seq, heads, dim] format
q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim]
k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=k_len,
softmax_scale=1.0 / math.sqrt(head_dim),
causal=causal,
)
# Convert back to [batch, seq, heads, dim]
attn_output = attn_output.unsqueeze(0).transpose(1, 2) # [1, q_len, num_heads, head_dim]
return attn_output
except Exception as e:
# Final fallback: PyTorch SDPA
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states,
attn_mask=None,
is_causal=causal,
scale=1.0 / math.sqrt(head_dim)
)
return attn_output
def _xattn_prefill(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
stride: int,
norm: float,
threshold: float,
block_size: int = 128,
use_triton: bool = True,
causal: bool = True,
chunk_size: Optional[int] = None,
keep_sink: bool = False,
keep_recent: bool = False,
) -> torch.Tensor:
"""
XAttention prefill implementation.
Args:
query_states: [batch, num_heads, q_len, head_dim]
key_states: [batch, num_heads, k_len, head_dim]
value_states: [batch, num_heads, k_len, head_dim]
... other params
Returns:
Attention output [batch, q_len, num_heads, head_dim]
"""
batch_size, num_heads, k_len, head_dim = key_states.shape
_, _, q_len, _ = query_states.shape
# Auto-compute chunk_size if not specified
if chunk_size is None:
chunk_size = int(
max(
min(
max(2048, 1 << (k_len - 1).bit_length()),
128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()),
),
2048,
)
)
# Phase 1: Estimate sparse pattern
attn_sums, approx_simple_mask = self._xattn_estimate(
query_states,
key_states,
block_size=block_size,
stride=stride,
norm=norm,
threshold=threshold,
chunk_size=chunk_size,
use_triton=use_triton,
causal=causal,
keep_sink=keep_sink,
keep_recent=keep_recent,
)
# Phase 2: Block sparse attention
# For now, use FlashAttention as fallback since block_sparse_attn_func may not be available
attn_output = self._block_sparse_attention_fallback(
query_states, key_states, value_states,
approx_simple_mask, block_size, q_len, k_len
)
return attn_output
def _xattn_estimate(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
block_size: int,
stride: int,
norm: float = 1,
softmax: bool = True,
threshold: float = 0.9,
chunk_size: int = 16384,
use_triton: bool = True,
causal: bool = True,
keep_sink: bool = False,
keep_recent: bool = False,
) -> torch.Tensor:
"""
Estimate sparse attention pattern using chunked computation.
Returns:
attn_sums: [batch, heads, q_blocks, k_blocks] - importance scores
simple_masks: [batch, heads, q_blocks, k_blocks] - boolean masks
"""
batch_size, num_kv_head, k_len, head_dim = key_states.shape
batch_size, num_q_head, q_len, head_dim = query_states.shape
k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
k_chunk_num = (k_len + k_num_to_pad) // chunk_size
k_block_num = (k_len + k_num_to_pad) // block_size
q_chunk_num = (q_len + q_num_to_pad) // chunk_size
q_block_num = (q_len + q_num_to_pad) // block_size
# Pad inputs
if k_num_to_pad > 0:
pad_key_states = F.pad(key_states, (0, 0, 0, k_num_to_pad), value=0)
else:
pad_key_states = key_states
if q_num_to_pad > 0:
pad_query_states = F.pad(query_states, (0, 0, 0, q_num_to_pad), value=0)
else:
pad_query_states = query_states
reshaped_chunk_size = chunk_size // stride
reshaped_block_size = block_size // stride
k_reshaped_seq_len = (k_len + k_num_to_pad) // stride
attn_sum_list = []
simple_mask_list = []
for chunk_idx in range(q_chunk_num):
if use_triton:
# Triton GEMM + Softmax
attn_weights_slice = flat_group_gemm_fuse_reshape(
pad_query_states[:, :, (chunk_idx * reshaped_chunk_size) * stride : (chunk_idx * reshaped_chunk_size + reshaped_chunk_size) * stride, :],
pad_key_states,
stride,
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
is_causal=causal,
)
attn_sum = softmax_fuse_block_sum(
attn_weights_slice,
reshaped_block_size,
min(4096, reshaped_block_size),
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
(k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
k_reshaped_seq_len - (k_num_to_pad // stride),
1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
is_causal=causal,
)
else:
# PyTorch fallback
chunk_size_actual = reshaped_chunk_size
chunk_start = chunk_idx * chunk_size_actual
chunk_end = chunk_start + chunk_size_actual
chunked_query = pad_query_states[:, :, chunk_start * stride:chunk_end * stride:stride, :]
attn_weights_slice = torch.matmul(chunked_query, pad_key_states.transpose(2, 3))
attn_weights_slice = attn_weights_slice / math.sqrt(head_dim) / stride / norm
if causal:
causal_mask = torch.zeros((batch_size, num_q_head, chunk_size_actual, chunk_size_actual * k_chunk_num), device=key_states.device)
causal_mask[:, :, :, -(k_num_to_pad // stride):] = float("-inf")
# ... more causal mask logic ...
attn_weights_slice = attn_weights_slice + causal_mask
attn_weights_slice = F.softmax(attn_weights_slice, dim=-1, dtype=torch.float32)
attn_sum = attn_weights_slice.view(batch_size, num_q_head, chunk_size_actual // reshaped_block_size, reshaped_block_size, -1).sum(dim=-1).sum(dim=-2)
# Find blocks based on threshold
simple_mask = find_blocks_chunked(
attn_sum,
k_block_num - q_block_num + chunk_idx * (reshaped_chunk_size // reshaped_block_size),
threshold,
None,
decoding=False,
mode="prefill",
causal=causal,
)
attn_sum_list.append(attn_sum)
simple_mask_list.append(simple_mask)
attn_sums = torch.cat(attn_sum_list, dim=-2)
simple_masks = torch.cat(simple_mask_list, dim=-2)
# Apply causal mask to block masks
if causal:
simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=key_states.device), diagonal=0),
simple_masks[:, :, -q_block_num:, -q_block_num:],
False,
)
if keep_sink:
simple_masks[:, :, 0, :] = True
if keep_recent:
eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=bool)
eye_matrix_expanded = eye_matrix.unsqueeze(0).unsqueeze(0).expand(1, num_q_head, q_block_num, q_block_num)
simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:]
)
return attn_sums, simple_masks
def _block_sparse_attention_fallback(
self,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
mask: torch.Tensor,
block_size: int,
q_len: int,
k_len: int,
) -> torch.Tensor:
"""
Fallback implementation using FlashAttention.
Since block_sparse_attn_func may not be available in all environments,
this uses standard FlashAttention with full attention.
"""
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_func
batch_size, num_heads, _, head_dim = query_states.shape
# Convert to [seq, heads, dim] format
q = query_states.squeeze(0).transpose(0, 1) # [q_len, num_heads, head_dim]
k = key_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
v = value_states.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)
attn_output = flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=k_len,
softmax_scale=1.0 / math.sqrt(head_dim),
causal=True,
)
# Convert back to [batch, seq, heads, dim]
attn_output = attn_output.unsqueeze(0).transpose(1, 2)
return attn_output
except Exception as e:
# Final fallback: PyTorch SDPA
print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
attn_output = F.scaled_dot_product_attention(
query_states, key_states, value_states,
attn_mask=None,
is_causal=True,
scale=1.0 / math.sqrt(query_states.shape[-1])
)
return attn_output
     def reset(self) -> None:
         """Reset policy state (no state to reset for XAttention)."""
         pass
@@ -305,6 +461,4 @@ class XAttentionPolicy(AttentionPolicy):
         return (f"XAttentionPolicy("
                 f"stride={self.stride}, "
                 f"threshold={self.threshold}, "
-                f"block_size={self.block_size}, "
-                f"use_triton={self.use_triton}, "
-                f"use_bsa={self.use_bsa})")
+                f"use_triton={self.use_triton})")
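The threshold rule behind step 3 of the policy ("Block selection based on threshold") can be sketched independently of the Triton kernels: given per-(q-block, k-block) importance sums, keep the smallest set of KV blocks whose share of the total attention mass reaches `threshold`. The function name and signature below are illustrative, not the actual find_blocks_chunked API.

```python
def select_blocks_by_threshold(attn_sums, threshold=0.9):
    """attn_sums: non-negative block importance scores for one query-block row.
    Returns a boolean keep-mask over KV blocks."""
    total = sum(attn_sums)
    # Greedily take the most important blocks first...
    order = sorted(range(len(attn_sums)), key=lambda i: attn_sums[i], reverse=True)
    mask = [False] * len(attn_sums)
    acc = 0.0
    for i in order:
        mask[i] = True
        acc += attn_sums[i]
        # ...until the kept blocks cover `threshold` of the attention mass.
        if acc >= threshold * total:
            break
    return mask

scores = [5.0, 0.2, 3.0, 0.3, 1.5]  # one row of block importance sums
print(select_blocks_by_threshold(scores, 0.9))
# [True, False, True, False, True]: blocks 0, 2, 4 cover 9.5/10.0 >= 90%
```

A higher threshold keeps more blocks (denser, more accurate); options like keep_sink and keep_recent then force specific blocks into the mask regardless of score.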


@@ -98,10 +98,10 @@ class Attention(nn.Module):
                 max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                 max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                 softmax_scale=self.scale, causal=True, block_table=context.block_tables)
-        elif context.attention_policy is not None:
-            # Attention via policy (GPU-only) - delegate to policy
-            o = context.attention_policy.compute_prefill(
-                q, k, v, self.layer_id, softmax_scale=self.scale
-            )
+        elif context.sparse_prefill_policy is not None:
+            # Sparse prefill (GPU-only) - delegate to policy
+            o = context.sparse_prefill_policy.sparse_prefill_attention(
+                q, k, v, self.layer_id
+            )
         else:
             o = flash_attn_varlen_func(q, k, v,


@@ -1,38 +0,0 @@
"""
Operators module for nano-vLLM.
This module contains low-level attention operators and kernels.
"""
from nanovllm.ops.chunked_attention import (
flash_attn_with_lse,
merge_attention_outputs,
chunked_attention_varlen,
ChunkedPrefillState,
)
from nanovllm.ops.xattn import (
xattn_estimate,
xattn_estimate_chunked,
flat_group_gemm_fuse_reshape,
softmax_fuse_block_sum,
find_blocks_chunked,
create_causal_mask,
compute_sparsity,
)
__all__ = [
# chunked_attention
"flash_attn_with_lse",
"merge_attention_outputs",
"chunked_attention_varlen",
"ChunkedPrefillState",
# xattn
"xattn_estimate",
"xattn_estimate_chunked",
"flat_group_gemm_fuse_reshape",
"softmax_fuse_block_sum",
"find_blocks_chunked",
"create_causal_mask",
"compute_sparsity",
]


@@ -1,624 +0,0 @@
"""
Chunked attention implementation for CPU KV cache offloading.
This module implements flash attention with LSE (log-sum-exp) output,
enabling proper online softmax merging for chunked prefill.
Key functions:
- flash_attn_with_lse: Flash attention that returns output and LSE
- merge_attention_outputs: Merge outputs from multiple KV chunks
- chunked_prefill_attention: High-level interface for chunked attention
"""
import math
import torch
import triton
import triton.language as tl
from typing import Tuple, List, Optional
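The LSE-based merge this module's docstring describes can be checked against a dense reference in a few lines of NumPy. This is a sketch of the math only (single query, batch-free shapes), not the Triton kernel's API: each KV chunk yields a locally-normalized output o_c and lse_c = log Σ exp(scores_c), and two partials merge exactly via their LSEs.

```python
import numpy as np

def attn_chunk(q, K, V):
    """Single-query attention over one KV chunk; returns (output, lse)."""
    s = K @ q                        # [n] attention scores for this chunk
    m = s.max()                      # stabilizer, as in the kernel's m_i
    p = np.exp(s - m)
    lse = m + np.log(p.sum())        # log-sum-exp of this chunk's scores
    return (p / p.sum()) @ V, lse    # locally-normalized output

def merge(o1, lse1, o2, lse2):
    """Combine two chunk partials into the exact full-softmax result."""
    m = max(lse1, lse2)
    w1, w2 = np.exp(lse1 - m), np.exp(lse2 - m)  # per-chunk partition weights
    return (w1 * o1 + w2 * o2) / (w1 + w2), m + np.log(w1 + w2)

rng = np.random.default_rng(0)
q = rng.normal(size=4)
K, V = rng.normal(size=(6, 4)), rng.normal(size=(6, 2))

o_full, lse_full = attn_chunk(q, K, V)        # attention over all 6 keys at once
o1, l1 = attn_chunk(q, K[:3], V[:3])          # chunk 1
o2, l2 = attn_chunk(q, K[3:], V[3:])          # chunk 2
o_merged, lse_merged = merge(o1, l1, o2, l2)
assert np.allclose(o_full, o_merged)          # chunked == full, to fp precision
```

Because w_c is proportional to each chunk's softmax partition sum, the merge is exact and associative, which is what lets chunked prefill process KV chunks in any streaming order.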
@triton.heuristics(
{
"EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
"EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
"EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
}
)
@triton.jit
def _fwd_kernel_with_lse(
Q,
K,
V,
Out,
Lse,
softmax_scale,
stride_qb,
stride_qh,
stride_qm,
stride_kb,
stride_kh,
stride_kn,
stride_vb,
stride_vh,
stride_vn,
stride_ob,
stride_oh,
stride_om,
nheads,
seqlen_q,
seqlen_k,
seqlen_q_rounded,
headdim,
CACHE_KEY_SEQLEN_Q,
CACHE_KEY_SEQLEN_K,
IS_CAUSAL: tl.constexpr,
BLOCK_HEADDIM: tl.constexpr,
EVEN_M: tl.constexpr,
EVEN_N: tl.constexpr,
EVEN_HEADDIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""
Flash attention forward kernel with LSE output.
Implements standard Flash Attention online softmax algorithm:
- m_i: running max of attention scores
- l_i: running sum of exp(scores - m_i)
- acc_o: running sum of softmax(scores) @ V (unnormalized)
Final output: acc_o / l_i
Final LSE: m_i + log(l_i)
"""
start_m = tl.program_id(0)
off_hb = tl.program_id(1)
off_b = off_hb // nheads
off_h = off_hb % nheads
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_HEADDIM)
# Pointers
q_ptrs = (
Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
)
k_ptrs = (
K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
)
v_ptrs = (
V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])
)
# Initialize running statistics
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") # running max
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) # running sum of exp
acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) # running output (unnormalized)
# Load Q (once per block)
if EVEN_M & EVEN_N:
if EVEN_HEADDIM:
q = tl.load(q_ptrs)
else:
q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
else:
q = tl.load(
q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
)
# Loop over K, V blocks
end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
for start_n in range(0, end_n, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# Load K
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
k = tl.load(k_ptrs + start_n * stride_kn)
else:
k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
k = tl.load(
k_ptrs + start_n * stride_kn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
# Compute QK^T * scale
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, tl.trans(k))
qk *= softmax_scale
# Apply masks
if not EVEN_N:
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
if IS_CAUSAL:
qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
# Online softmax: compute block max
m_ij = tl.max(qk, 1) # [BLOCK_M]
# New running max
m_new = tl.maximum(m_i, m_ij) # [BLOCK_M]
# Rescale factor for previous accumulator
alpha = tl.exp(m_i - m_new) # [BLOCK_M]
# Compute P = exp(qk - m_new)
p = tl.exp(qk - m_new[:, None]) # [BLOCK_M, BLOCK_N]
# Sum of current block
l_ij = tl.sum(p, 1) # [BLOCK_M]
# Update running sum: l_new = l_i * alpha + l_ij
l_new = l_i * alpha + l_ij
# Rescale previous output and add new contribution
acc_o = acc_o * alpha[:, None]
# Load V
if EVEN_N & EVEN_M:
if EVEN_HEADDIM:
v = tl.load(v_ptrs + start_n * stride_vn)
else:
v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
else:
if EVEN_HEADDIM:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=(start_n + offs_n)[:, None] < seqlen_k,
other=0.0,
)
else:
v = tl.load(
v_ptrs + start_n * stride_vn,
mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
other=0.0,
)
# acc_o += P @ V
p = p.to(v.dtype)
acc_o += tl.dot(p, v)
# Update running statistics
m_i = m_new
l_i = l_new
# Final normalization: output = acc_o / l_i
acc_o = acc_o / l_i[:, None]
# Compute LSE = m_i + log(l_i)
lse_i = m_i + tl.log(l_i)
# Store LSE
lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
if EVEN_M:
tl.store(lse_ptrs, lse_i)
else:
tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)
# Store output
out_ptrs = (
Out
+ off_b * stride_ob
+ off_h * stride_oh
+ (offs_m[:, None] * stride_om + offs_d[None, :])
)
if EVEN_M:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o)
else:
tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
else:
if EVEN_HEADDIM:
tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
else:
tl.store(
out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
)

def flash_attn_with_lse(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flash attention forward pass that returns both output and LSE.

    Uses the flash_attn library, which natively supports GQA without memory overhead.

    Args:
        q: Query tensor [batch, seqlen_q, nheads_q, headdim]
        k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
        v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
        softmax_scale: Scaling factor (default: 1/sqrt(headdim))
        causal: Whether to apply causal masking

    Returns:
        out: Output tensor [batch, seqlen_q, nheads_q, headdim]
        lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
    """
    from flash_attn.flash_attn_interface import flash_attn_func

    batch, seqlen_q, nheads_q, headdim = q.shape
    _, seqlen_k, nheads_kv, _ = k.shape
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)
    # flash_attn_func natively supports GQA (no memory overhead).
    # With return_attn_probs=True it returns (out, softmax_lse, S_dmask),
    # which is how we obtain the LSE.
    out, lse, _ = flash_attn_func(
        q, k, v,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=True,
    )
    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
    # Trim to actual seqlen_q
    lse = lse[:, :, :seqlen_q]
    return out, lse

@triton.jit
def _merge_lse_kernel(
    lse1_ptr, lse2_ptr, lse_out_ptr,
    num_elements: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused kernel for merging LSE values.

    IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
    bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
    """
    # Each program handles BLOCK_SIZE elements
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements
    # Load LSE values and convert to fp32 for precision
    lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
    lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)
    # Compute max for numerical stability (in fp32)
    max_lse = tl.maximum(lse1, lse2)
    # Compute exp(lse - max_lse) in fp32
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)
    # Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
    lse_merged = max_lse + tl.log(exp1 + exp2)
    # Store result (converted back to the output dtype)
    tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)

@triton.jit
def _merge_output_kernel(
    o1_ptr, o2_ptr, lse1_ptr, lse2_ptr, o_out_ptr,
    batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused kernel for merging attention outputs.

    IMPORTANT: Uses fp32 for exp operations and the weighted sum to avoid
    precision loss. This is critical for numerical accuracy in chunked attention.
    """
    # Each program handles BLOCK_SIZE elements along headdim for one
    # (batch, seqlen_q, nheads) position
    pid_batch = tl.program_id(0)
    pid_seq = tl.program_id(1)
    pid_head = tl.program_id(2)
    # Compute LSE index: [batch, nheads, seqlen_q]
    lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq
    # Load LSE values and convert to fp32 for precision
    lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
    lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)
    # Compute max and scaling factors in fp32
    max_lse = tl.maximum(lse1, lse2)
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)
    sum_exp = exp1 + exp2
    # Process headdim in chunks
    for d_offset in range(0, headdim, BLOCK_SIZE):
        d_idx = d_offset + tl.arange(0, BLOCK_SIZE)
        mask = d_idx < headdim
        # Compute output index: [batch, seqlen_q, nheads, headdim]
        base_idx = (pid_batch * seqlen_q * nheads * headdim +
                    pid_seq * nheads * headdim +
                    pid_head * headdim)
        o_idx = base_idx + d_idx
        # Load o1, o2 and convert to fp32 for the weighted sum
        o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
        o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
        # Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
        o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp
        # Store result (Triton converts back to the output dtype)
        tl.store(o_out_ptr + o_idx, o_merged, mask=mask)

def merge_attention_outputs(
    o1: torch.Tensor,
    lse1: torch.Tensor,
    o2: torch.Tensor,
    lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge two attention outputs using online softmax (Triton fused kernels).

    This implements the online softmax merging formula:
    - m_new = max(lse1, lse2)
    - o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2) / (exp(lse1 - m_new) + exp(lse2 - m_new))
    - lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new))

    Args:
        o1: First output [batch, seqlen_q, nheads, headdim]
        lse1: First LSE [batch, nheads, seqlen_q]
        o2: Second output [batch, seqlen_q, nheads, headdim]
        lse2: Second LSE [batch, nheads, seqlen_q]

    Returns:
        o_merged: Merged output [batch, seqlen_q, nheads, headdim]
        lse_merged: Merged LSE [batch, nheads, seqlen_q]
    """
    batch, seqlen_q, nheads, headdim = o1.shape
    # Allocate output tensors
    o_merged = torch.empty_like(o1)
    lse_merged = torch.empty_like(lse1)
    # Launch LSE merge kernel
    num_lse_elements = batch * nheads * seqlen_q
    BLOCK_SIZE_LSE = 256
    grid_lse = (triton.cdiv(num_lse_elements, BLOCK_SIZE_LSE),)
    _merge_lse_kernel[grid_lse](
        lse1, lse2, lse_merged,
        num_lse_elements,
        BLOCK_SIZE=BLOCK_SIZE_LSE,
    )
    # Launch output merge kernel
    BLOCK_SIZE = 128
    grid_output = (batch, seqlen_q, nheads)
    _merge_output_kernel[grid_output](
        o1, o2, lse1, lse2, o_merged,
        batch, seqlen_q, nheads, headdim,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    return o_merged, lse_merged
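The merge formula above can be sanity-checked without Triton or a GPU. This pure-PyTorch mirror is illustrative only (the real path uses the fused kernels); it shows that attending to two K/V chunks separately and merging via LSE reproduces full-attention softmax exactly:

```python
import torch


def _merge_outputs_reference(o1, lse1, o2, lse2):
    """Pure-PyTorch mirror of the merge formula (single head, batch folded).

    o1/o2: [m, d]; lse1/lse2: [m].
    """
    m_new = torch.maximum(lse1, lse2)
    e1 = torch.exp(lse1 - m_new)
    e2 = torch.exp(lse2 - m_new)
    o = (o1 * e1[:, None] + o2 * e2[:, None]) / (e1 + e2)[:, None]
    return o, m_new + torch.log(e1 + e2)


def _naive(q, k, v):
    # Naive single-head attention returning (out, lse); scale=1 for brevity.
    s = q @ k.T
    return torch.softmax(s, -1) @ v, torch.logsumexp(s, -1)


torch.manual_seed(0)
q = torch.randn(5, 4); k = torch.randn(6, 4); v = torch.randn(6, 4)
o_full, lse_full = _naive(q, k, v)
# Split K/V into two chunks, attend per chunk, then merge.
o1, l1 = _naive(q, k[:3], v[:3])
o2, l2 = _naive(q, k[3:], v[3:])
o_merged, lse_merged = _merge_outputs_reference(o1, l1, o2, l2)
```

The merged result matches the full computation because each chunk's output is rescaled by exp(lse - m_new), which restores the global softmax denominator.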

def chunked_attention_varlen(
    q: torch.Tensor,
    kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k_list: List[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k_list: List[int],
    softmax_scale: Optional[float] = None,
    causal_mask_per_chunk: Optional[List[bool]] = None,
) -> torch.Tensor:
    """
    Compute attention with KV split across multiple chunks.

    This is the core function for chunked prefill. It computes attention
    against each KV chunk and merges the results using online softmax.

    For causal attention with chunked KV:
    - Last chunk (current tokens): apply causal mask
    - Previous chunks: no causal mask (all previous tokens are valid context)

    Args:
        q: Query tensor [total_q_tokens, nheads, headdim]
        kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim]
        cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
        cu_seqlens_k_list: List of cumulative sequence lengths for each KV chunk
        max_seqlen_q: Maximum query sequence length
        max_seqlen_k_list: List of maximum key sequence lengths for each chunk
        softmax_scale: Scaling factor
        causal_mask_per_chunk: Whether to apply a causal mask for each chunk

    Returns:
        out: Output tensor [total_q_tokens, nheads, headdim]
    """
    if len(kv_chunks) == 0:
        raise ValueError("Need at least one KV chunk")
    nheads = q.shape[1]
    headdim = q.shape[2]
    batch = cu_seqlens_q.shape[0] - 1
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)
    if causal_mask_per_chunk is None:
        # Default: causal for the last chunk only
        causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True]
    # Initialize accumulated output and LSE
    accumulated_o = None
    accumulated_lse = None
    for chunk_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
        is_causal = causal_mask_per_chunk[chunk_idx]
        # Reshape Q for batch processing.
        # For varlen we would need to handle each sequence separately;
        # for simplicity, assume a single sequence (batch=1) for now.
        q_batched = q.unsqueeze(0)  # [1, total_q, nheads, headdim]
        # Compute attention for this chunk
        chunk_o, chunk_lse = flash_attn_with_lse(
            q_batched,
            k_chunk,
            v_chunk,
            softmax_scale=softmax_scale,
            causal=is_causal,
        )
        # Merge with the accumulated result
        if accumulated_o is None:
            accumulated_o = chunk_o
            accumulated_lse = chunk_lse
        else:
            accumulated_o, accumulated_lse = merge_attention_outputs(
                accumulated_o, accumulated_lse,
                chunk_o, chunk_lse,
            )
    # Remove batch dimension
    return accumulated_o.squeeze(0)

class ChunkedPrefillState:
    """
    State for tracking chunked prefill progress.

    This class maintains the accumulated attention output and LSE
    across multiple prefill chunks.
    """

    def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device):
        self.num_layers = num_layers
        self.dtype = dtype
        self.device = device
        # Per-layer accumulated outputs.
        # Each entry: (accumulated_output, accumulated_lse) or None
        self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
            None for _ in range(num_layers)
        ]
        # Track how many chunks have been processed
        self.processed_chunks: int = 0

    def update_layer(
        self,
        layer_id: int,
        chunk_output: torch.Tensor,
        chunk_lse: torch.Tensor,
    ):
        """Update accumulated state for a layer with a new chunk's output."""
        if self.layer_states[layer_id] is None:
            self.layer_states[layer_id] = (chunk_output, chunk_lse)
        else:
            acc_o, acc_lse = self.layer_states[layer_id]
            merged_o, merged_lse = merge_attention_outputs(
                acc_o, acc_lse,
                chunk_output, chunk_lse,
            )
            self.layer_states[layer_id] = (merged_o, merged_lse)

    def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]:
        """Get the final accumulated output for a layer."""
        if self.layer_states[layer_id] is None:
            return None
        return self.layer_states[layer_id][0]

    def clear(self):
        """Clear all accumulated state."""
        self.layer_states = [None for _ in range(self.num_layers)]
        self.processed_chunks = 0

# Test function
def _test_chunked_attention():
    """Test chunked attention using flash_attn_with_lse and merge_attention_outputs."""
    from flash_attn.flash_attn_interface import flash_attn_func

    torch.manual_seed(42)
    print("=" * 70)
    print("Test: Chunked attention vs flash_attn_func (non-causal)")
    print("=" * 70)
    print("Splitting K,V into chunks, computing attention per chunk, then merging")
    print()
    for dtype in [torch.float16, torch.bfloat16]:
        for num_chunks in [64, 128, 256]:
            for batch, seqlen, nheads, headdim in [
                (1, 1024, 32, 128),
                (1, 2048, 32, 128),
                (1, 4096, 32, 128),
                (1, 8192, 32, 128),
            ]:
                # Generate random Q, K, V
                q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
                k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
                v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
                # Reference: full attention (non-causal)
                out_ref = flash_attn_func(q, k, v, causal=False)
                # Chunked attention: split K, V into chunks
                chunk_size = seqlen // num_chunks
                accumulated_o = None
                accumulated_lse = None
                for i in range(num_chunks):
                    start = i * chunk_size
                    end = (i + 1) * chunk_size
                    k_chunk = k[:, start:end, :, :]
                    v_chunk = v[:, start:end, :, :]
                    # Q attends to this K,V chunk (non-causal)
                    chunk_o, chunk_lse = flash_attn_with_lse(
                        q, k_chunk, v_chunk, causal=False
                    )
                    if accumulated_o is None:
                        accumulated_o = chunk_o
                        accumulated_lse = chunk_lse
                    else:
                        # Merge with previous chunks
                        accumulated_o, accumulated_lse = merge_attention_outputs(
                            accumulated_o, accumulated_lse,
                            chunk_o, chunk_lse
                        )
                # Compare
                out_diff = (out_ref - accumulated_o).abs()
                out_max_diff = out_diff.max().item()
                out_mean_diff = out_diff.mean().item()
                status = "PASS" if out_max_diff < 1e-2 else "FAIL"
                print(
                    f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} "
                    f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) "
                    f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}"
                )
            print()
    print("=" * 70)
    print("Test completed!")


if __name__ == "__main__":
    _test_chunked_attention()



@@ -14,9 +14,9 @@ class Context:
     context_lens: torch.Tensor | None = None
     block_tables: torch.Tensor | None = None
-    # Attention policy support (GPU-only path)
-    # When set, uses policy.compute_prefill() instead of FlashAttention
-    attention_policy: Any = None  # AttentionPolicy instance
+    # Sparse prefill attention support (GPU-only path)
+    # When set, uses policy.sparse_prefill_attention() instead of FlashAttention
+    sparse_prefill_policy: Any = None  # SparsePolicy instance with supports_prefill=True

 _CONTEXT = Context()
@@ -35,7 +35,7 @@ def set_context(
     slot_mapping=None,
     context_lens=None,
     block_tables=None,
-    attention_policy=None,
+    sparse_prefill_policy=None,
 ):
     global _CONTEXT
     _CONTEXT = Context(
@@ -47,7 +47,7 @@ def set_context(
         slot_mapping=slot_mapping,
         context_lens=context_lens,
         block_tables=block_tables,
-        attention_policy=attention_policy,
+        sparse_prefill_policy=sparse_prefill_policy,
     )

notes.md

@@ -1,130 +0,0 @@
# Notes: SparsePolicy Refactoring Research
## Sources
### Source 1: tzj/minference branch - policy.py
- Path: `nanovllm/kvcache/sparse/policy.py`
- Key design points:
  - The `PolicyContext` dataclass carries query_chunk_idx, num_query_chunks, layer_id, query, is_prefill, etc.
  - `select_blocks()` requires an offload_engine argument
  - `compute_chunked_prefill()` and `compute_chunked_decode()` implement the complete attention flow
  - The `on_prefill_offload()` / `on_decode_offload()` hooks collect metadata
### Source 2: tzj/minference branch - full_policy.py
- Path: `nanovllm/kvcache/sparse/full_policy.py`
- Key implementation details:
  - `compute_chunked_prefill()` uses a ring-buffer pipeline internally to load blocks
  - Uses `flash_attn_with_lse` and `merge_attention_outputs` to merge attention across chunks
  - `compute_chunked_decode()` handles prefilled blocks plus the decode buffer
### Source 3: tzj/layer-offload branch - model_runner.py
- Path: `nanovllm/engine/model_runner.py`
- Key design points:
  - `run_layerwise_offload_prefill()` processes layer by layer; each layer computes full attention
  - `sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)` is a simple interface
  - The FULL policy takes the else branch via `if sparse_prefill_policy is None`
### Source 4: tzj/layer-offload branch - xattn.py
- Path: `nanovllm/kvcache/sparse/xattn.py`
- Key implementation details:
  - `sparse_prefill_attention()` uses FlashAttention directly (due to chunked-prefill architectural constraints)
  - The Triton kernels are kept for a future GPU-only mode
## Synthesized Findings
### Architecture Differences
| Aspect | Chunked Offload | Layerwise Offload |
|------|-----------------|-------------------|
| **Prefill flow** | chunk by chunk, across layers | layer by layer, full sequence |
| **KV storage** | each chunk offloaded immediately | each layer offloaded after compute |
| **Attention compute** | multiple partial computes + merge | one full compute |
| **Block loading** | must load history from CPU | not needed, already on GPU |
| **Policy responsibility** | full attention flow | attention kernel selection only |
### Simplifications Enabled by Layerwise Offload
1. **No block selection needed**: the whole layer's KV is on GPU, nothing to select
2. **No offload_engine parameter**: the policy does not load KV
3. **No merge_attention_outputs**: attention is computed once over the full sequence
4. **No offload hooks**: offloading is handled centrally in model_runner
### Design Recommendations
1. **Keep the interface simple**: only `compute_prefill_attention()` and `compute_decode_attention()` are needed
2. **FULL also implements the methods**: drop the `is None` check; all policies are called uniformly
3. **Remove unnecessary parameters**: offload_engine, kvcache_manager, seq are not needed
4. **Unify naming**: use `compute_*_attention` instead of `sparse_prefill_attention`
## Code Examples
### Current Call Pattern (model_runner.py:876-891)
```python
# Sparse or Full attention
if self.sparse_prefill_policy is not None:
    # MInference or other sparse prefill policy
    attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
        q, k, v, layer_id
    )
else:
    # Full attention using FlashAttention
    attn_output = flash_attn_varlen_func(
        q, k, v, ...
    )
```
### Proposed New Call Pattern
```python
# All policies are called uniformly
attn_output = self.attention_policy.compute_prefill_attention(
    q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
)
```
## Questions Resolved
- Q: Is `PolicyContext` still needed?
  - A: It can be simplified; layerwise mode needs no chunk information
- Q: How is the decode stage handled?
  - A: **Decode needs no policy!** The current `run_layerwise_offload_decode()` uses the standard `layer(positions, hidden_states, residual)` call, going through the Attention.forward() path
- Q: Why does decode not need sparsity?
  - A: Decode handles only 1 token per step, so there is nothing to sparsify. KV is loaded from the ring buffer and used directly with flash_attn_with_kvcache
## Key Insight
**The policy design for layerwise offload should focus on prefill only.**
```
Prefill: needs a policy
- Attention is computed once over the whole sequence
- Sparse attention methods apply (e.g., MInference's vertical+slash pattern)
- The policy receives q, k, v, layer_id, softmax_scale

Decode: needs no policy
- Only 1 query token per step
- KV is loaded from the ring buffer
- Uses the standard flash_attn_with_kvcache
```
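The decode-side claim can be sanity-checked with a tiny PyTorch sketch (illustrative names, not repo code): for a single query token at the last position, the causal mask excludes nothing, so plain non-causal attention over the cached KV equals the last row of full causal attention.

```python
import torch


def decode_step(q, k_cache, v_cache):
    """Single-token decode: q [1, d], cache [n, d]. No causal mask needed,
    because the query sits at the last position and may see every cached token."""
    s = (q @ k_cache.T) / (q.shape[-1] ** 0.5)
    return torch.softmax(s, dim=-1) @ v_cache


torch.manual_seed(0)
n, d = 10, 8
q_all = torch.randn(n, d); k = torch.randn(n, d); v = torch.randn(n, d)
# Reference: full causal attention over the whole sequence
s = (q_all @ k.T) / (d ** 0.5)
s = s.masked_fill(torch.triu(torch.ones(n, n, dtype=torch.bool), 1), float("-inf"))
causal_out = torch.softmax(s, dim=-1) @ v
# Decode step for the last token: matches the last causal row exactly
decode_out = decode_step(q_all[-1:], k, v)
```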
## Interface Comparison Summary
| Aspect | tzj/minference | tzj/layer-offload (new design) |
|------|----------------|---------------------------|
| Class name | SparsePolicy | AttentionPolicy |
| Prefill method | compute_chunked_prefill() | compute_attention() |
| Decode method | compute_chunked_decode() | not needed (standard path) |
| Needs offload_engine | yes | no |
| Needs kvcache_manager | yes | no |
| Needs seq | yes | no |
| Supports FULL | yes | yes |
## Migration Path
1. Keep `SparsePolicy` as an alias of `AttentionPolicy`
2. Keep `PolicyContext` for future extension
3. Keep the `select_blocks()` method signature (even though unused)
4. Remove the `requires_block_selection` attribute (not needed)


@@ -1,549 +0,0 @@
# Task Plan: Refactor SparsePolicy for Layerwise Offload
## Goal
Refactor the SparsePolicy interface following the design pattern of the tzj/minference branch, so that every attention variant can be abstracted as a policy written to a single convention, adapted to the current layerwise-offload architecture (the whole layer's KV resides on GPU).
## Background
### Comparison of the Two Offload Architectures
| Feature | tzj/minference (Chunked Offload) | tzj/layer-offload (Layerwise Offload) |
|------|----------------------------------|---------------------------------------|
| Processing granularity | one chunk at a time (block_size tokens) | one whole layer at a time (all tokens) |
| KV location | history chunks on CPU, must be loaded | whole layer's KV on GPU |
| Policy entry points | `compute_chunked_prefill()/decode()` | `compute_prefill()/decode()` |
| Needs offload_engine | yes (loads blocks) | no (KV already on GPU) |
| Mask computation | `select_blocks()` returns block IDs | `estimate()` returns a sparse mask |
### Policy Interface in tzj/minference
```python
class SparsePolicy(ABC):
    supports_prefill: bool
    supports_decode: bool

    @abstractmethod
    def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]

    @abstractmethod
    def compute_chunked_prefill(self, q, k, v, layer_id, ..., offload_engine, ...) -> Tensor

    @abstractmethod
    def compute_chunked_decode(self, q, layer_id, ..., offload_engine, ...) -> Tensor
```
### Policy Interface on the Current Branch (before refactoring)
```python
class SparsePolicy(ABC):
    supports_prefill: bool
    supports_decode: bool

    @abstractmethod
    def select_blocks(self, available_blocks, ctx) -> List[int]

    def sparse_prefill_attention(self, q, k, v, layer_id) -> Tensor
```
## Phases
- [x] Phase 1: Analyze differences and design the new interface
- [x] **Phase 0: Create the nanovllm.ops module** ✅ tests passing
- [ ] Phase 2: Refactor the AttentionPolicy base class
- [ ] Phase 3: Refactor FullAttentionPolicy
- [ ] Phase 4: Refactor XAttentionPolicy (including the estimate method)
- [ ] Phase 5: Update the model_runner call sites
- [ ] Phase 6: Test and validate
---
## Phase 0: Create the nanovllm.ops Module
### Goal
Extract the ops module from the tzj/minference branch to provide the low-level operators required by XAttention estimation.
### Steps
1. **Create the directory structure**
```
nanovllm/ops/
├── __init__.py
├── xattn.py              # xattn_estimate, xattn_estimate_chunked, Triton kernels
└── chunked_attention.py  # flash_attn_with_lse, merge_attention_outputs (backup)
```
2. **Extract the files from tzj/minference**
```bash
git show tzj/minference:nanovllm/ops/__init__.py > nanovllm/ops/__init__.py
git show tzj/minference:nanovllm/ops/xattn.py > nanovllm/ops/xattn.py
git show tzj/minference:nanovllm/ops/chunked_attention.py > nanovllm/ops/chunked_attention.py
```
3. **Cherry-pick the test file**
```bash
git show tzj/minference:tests/test_xattn_estimate_chunked.py > tests/test_xattn_estimate_chunked.py
```
4. **Run the tests to verify**
```bash
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/Worktree/nano-vllm:$PYTHONPATH \
    python tests/test_xattn_estimate_chunked.py
```
### Contents of nanovllm/ops
| File | Core function | Purpose |
|------|----------|------|
| `xattn.py` | `xattn_estimate()` | standard XAttention estimation |
| `xattn.py` | `xattn_estimate_chunked()` | chunked-prefill variant |
| `xattn.py` | `flat_group_gemm_fuse_reshape()` | Triton kernel: fused reshape + GEMM |
| `xattn.py` | `softmax_fuse_block_sum()` | Triton kernel: softmax + block sum |
| `xattn.py` | `find_blocks_chunked()` | threshold-based block selection |
| `chunked_attention.py` | `flash_attn_with_lse()` | flash attention with LSE output |
| `chunked_attention.py` | `merge_attention_outputs()` | merge multiple attention chunks |
### Relationship to the Policy
```
XAttentionPolicy.estimate()
└── calls nanovllm.ops.xattn.xattn_estimate()
    ├── flat_group_gemm_fuse_reshape()  (Triton)
    ├── softmax_fuse_block_sum()        (Triton)
    └── find_blocks_chunked()
```
---
## Key Questions
1. **What does `select_blocks` become?**
   - Renamed to `estimate()`: it computes the sparse mask
   - For XAttention this corresponds to COMPASS's `xattn_estimate()` function
   - `FullAttentionPolicy.estimate()` returns None, meaning full attention
2. **How should the policy interface look?**
   - Prefill: `compute_prefill(q, k, v, layer_id, softmax_scale)`
   - Decode: `compute_decode(q, k, v, layer_id, softmax_scale)`
   - Estimate: `estimate(q, k, layer_id)` - computes the sparse mask
3. **How is the FULL policy handled?**
   - FULL also implements `compute_prefill/decode`, using FlashAttention
   - Its `estimate()` returns None, meaning no sparsification
## Proposed New Interface
```python
from abc import ABC, abstractmethod
from typing import Optional

import torch


class AttentionPolicy(ABC):
    """Attention policy for layerwise-offload mode.

    All attention computation goes through a policy, both Full and Sparse.
    Supports both the prefill and decode stages.
    """
    supports_prefill: bool = True
    supports_decode: bool = True

    def estimate(
        self,
        q: torch.Tensor,  # [seq_len, num_heads, head_dim]
        k: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate the sparse attention mask.

        Sparse policies (e.g., XAttention) compute which blocks to attend to.
        The full policy returns None, meaning full attention.
        Corresponds to COMPASS's xattn_estimate() function.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask, or None
        """
        return None  # default: full attention

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,  # [seq_len, num_heads, head_dim]
        k: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        v: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute prefill attention.

        The whole layer's KV is on GPU, so attention is computed once over
        the full sequence. A policy may first call estimate() to obtain a
        sparse mask and then apply block-sparse attention.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        pass

    def compute_decode(
        self,
        q: torch.Tensor,  # [1, num_heads, head_dim]
        k: torch.Tensor,  # [context_len+1, num_kv_heads, head_dim]
        v: torch.Tensor,  # [context_len+1, num_kv_heads, head_dim]
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute decode attention.

        KV comes from the ring buffer and contains the prefill tokens plus
        the tokens decoded so far.

        Args:
            q: Query tensor [1, num_heads, head_dim]
            k: Key tensor [context_len+1, num_kv_heads, head_dim]
            v: Value tensor [context_len+1, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [1, num_heads, head_dim]
        """
        # Default implementation: FlashAttention
        from flash_attn.flash_attn_interface import flash_attn_varlen_func
        context_len = k.shape[0]
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=context_len,
            softmax_scale=softmax_scale,
            causal=False,
        )

    def reset(self) -> None:
        """Reset policy state between sequences."""
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


# Keep the old name as an alias
SparsePolicy = AttentionPolicy
```
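Under the proposed interface, a concrete policy only has to fill in `compute_prefill`. The sketch below is illustrative (it uses a stub base class so the snippet is self-contained, a naive O(n²)-memory causal attention instead of FlashAttention, and assumes num_heads == num_kv_heads); it is a reference for how a policy plugs in, not shipped code:

```python
from abc import ABC, abstractmethod

import torch


class AttentionPolicy(ABC):  # stub of the proposed base class
    @abstractmethod
    def compute_prefill(self, q, k, v, layer_id, softmax_scale):
        ...


class NaiveFullPolicy(AttentionPolicy):
    """Toy policy: materializes the full score matrix, so it is only a
    CPU-testable reference, not a FlashAttention replacement."""

    def compute_prefill(self, q, k, v, layer_id, softmax_scale):
        # q, k, v: [seq_len, num_heads, head_dim] -> [heads, seq, dim]
        qh, kh, vh = (t.transpose(0, 1) for t in (q, k, v))
        scores = softmax_scale * (qh @ kh.transpose(-1, -2))
        causal = torch.triu(torch.ones(q.shape[0], q.shape[0], dtype=torch.bool), 1)
        scores.masked_fill_(causal, float("-inf"))
        out = torch.softmax(scores, dim=-1) @ vh
        return out.transpose(0, 1)  # back to [seq_len, num_heads, head_dim]
```

A caller would use it exactly like the FULL policy: `out = policy.compute_prefill(q, k, v, layer_id, softmax_scale)`, which is the uniform call shape Phase 5 aims for.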
## Implementation Plan
### Phase 2: Refactor policy.py
```python
# nanovllm/kvcache/sparse/policy.py
from abc import ABC, abstractmethod
from typing import Optional

import torch


class AttentionPolicy(ABC):
    """Base class for attention policies in layerwise offload mode."""
    supports_prefill: bool = True
    supports_decode: bool = True

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate the sparse attention mask.

        For sparse policies (e.g., XAttention), computes block-level importance.
        For the full policy, returns None.
        Corresponds to xattn_estimate() in COMPASS.

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] or None
        """
        return None

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute prefill attention."""
        pass

    def compute_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute decode attention (default: FlashAttention)."""
        from flash_attn.flash_attn_interface import flash_attn_varlen_func
        context_len = k.shape[0]
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=context_len,
            softmax_scale=softmax_scale,
            causal=False,
        )

    def reset(self) -> None:
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


# Backward compatibility alias
SparsePolicy = AttentionPolicy
```
### Phase 3: Refactor FullAttentionPolicy
```python
# nanovllm/kvcache/sparse/full_policy.py
import torch

from .policy import AttentionPolicy


class FullAttentionPolicy(AttentionPolicy):
    """Full attention using FlashAttention (no sparsity)."""
    supports_prefill = True
    supports_decode = True

    def estimate(self, q, k, layer_id):
        """Full attention - no sparse mask needed."""
        return None

    def compute_prefill(self, q, k, v, layer_id, softmax_scale):
        from flash_attn.flash_attn_interface import flash_attn_varlen_func
        seq_len = q.shape[0]
        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            softmax_scale=softmax_scale,
            causal=True,
        )

    def __repr__(self):
        return "FullAttentionPolicy()"
```
### Phase 4: 重构 XAttentionPolicy
```python
# nanovllm/kvcache/sparse/xattn.py
import torch
from typing import Optional
from .policy import AttentionPolicy
class XAttentionPolicy(AttentionPolicy):
"""
XAttention sparse prefill policy.
Uses chunked estimation to compute sparse attention mask,
then applies block sparse attention.
"""
supports_prefill = True
supports_decode = True
def __init__(
self,
stride: int = 8,
threshold: float = 0.9,
block_size: int = 128,
chunk_size: int = 16384,
use_triton: bool = True,
):
self.stride = stride
self.threshold = threshold
self.block_size = block_size
self.chunk_size = chunk_size
self.use_triton = use_triton
def estimate(
self,
q: torch.Tensor,
k: torch.Tensor,
layer_id: int,
) -> Optional[torch.Tensor]:
"""
XAttention estimation (xattn_estimate).
Uses chunked GEMM + softmax to estimate block-level importance,
then selects important blocks based on threshold.
对应 COMPASS 的 xattn_estimate() 函数:
1. Pad inputs to chunk_size multiples
2. Reshape with stride
3. Compute QK^T in chunks (Triton)
4. Block-wise softmax + aggregation
5. Threshold-based selection
Args:
q: [seq_len, num_heads, head_dim]
k: [seq_len, num_kv_heads, head_dim]
layer_id: transformer layer index
Returns:
sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask
or None (fallback to full attention)
"""
        # TODO: implement the real xattn_estimate
        # Currently returns None, falling back to full attention
return None
def compute_prefill(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer_id: int,
softmax_scale: float,
) -> torch.Tensor:
"""
Compute XAttention sparse prefill.
Flow:
1. Call estimate() to get sparse mask
2. If mask is None, use full attention
3. Otherwise, apply block sparse attention with mask
"""
# Step 1: Estimate sparse mask
sparse_mask = self.estimate(q, k, layer_id)
# Step 2: Compute attention
if sparse_mask is None:
# Fallback to full attention
from flash_attn.flash_attn_interface import flash_attn_varlen_func
seq_len = q.shape[0]
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
return flash_attn_varlen_func(
q, k, v,
cu_seqlens_q=cu_seqlens,
cu_seqlens_k=cu_seqlens,
max_seqlen_q=seq_len,
max_seqlen_k=seq_len,
softmax_scale=softmax_scale,
causal=True,
)
else:
# Apply block sparse attention with mask
            # Use block_sparse_attn_func(q, k, v, sparse_mask, block_size)
raise NotImplementedError("Block sparse attention not yet implemented")
def __repr__(self):
return (f"XAttentionPolicy("
f"stride={self.stride}, "
f"threshold={self.threshold}, "
f"block_size={self.block_size})")
```
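Until the real `xattn_estimate` port lands, the block-selection idea behind `estimate()` can be illustrated with a simplified single-head sketch — mean-pooled block scores instead of COMPASS's strided reshaping and Triton kernels, and `estimate_block_mask` is a hypothetical name:

```python
import torch

def estimate_block_mask(q, k, block_size=4, threshold=0.9):
    """Simplified block-importance estimate: score block pairs via pooled
    QK^T, then per Q-block keep the highest-probability K-blocks until
    their cumulative softmax mass reaches `threshold`.
    q, k: [seq_len, head_dim] (single head); returns [q_blocks, k_blocks] bool."""
    seq_len, head_dim = q.shape
    nb = seq_len // block_size
    # Pool each block to a single vector, then score block pairs
    qb = q[: nb * block_size].view(nb, block_size, head_dim).mean(dim=1)
    kb = k[: nb * block_size].view(nb, block_size, head_dim).mean(dim=1)
    scores = (qb @ kb.T) * head_dim ** -0.5                  # [nb, nb]
    causal = torch.tril(torch.ones(nb, nb, dtype=torch.bool))
    scores = scores.masked_fill(~causal, float("-inf"))
    probs = torch.softmax(scores, dim=-1)
    # Greedy selection: sort blocks by probability, keep until mass >= threshold
    sorted_p, order = probs.sort(dim=-1, descending=True)
    keep_sorted = sorted_p.cumsum(dim=-1) - sorted_p < threshold
    mask = torch.zeros_like(probs).scatter(
        -1, order, keep_sorted.to(probs.dtype)) > 0.5
    return mask & causal
```

The real estimator works per head on `[seq_len, num_heads, head_dim]` inputs and uses stride-based antidiagonal pooling, but the threshold-on-cumulative-mass selection is the same idea.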
### Phase 5: Update model_runner.py
```python
# model_runner.py - allocate_kv_cache()
# Always create a policy (including FULL)
from nanovllm.kvcache.sparse import create_attention_policy
self.attention_policy = create_attention_policy(config.attention_policy, **policy_kwargs)
logger.info(f"Attention policy: {self.attention_policy}")
# run_layerwise_offload_prefill() and run_gpu_only_prefill()
# Old code:
if self.sparse_prefill_policy is not None:
attn_output = self.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
else:
attn_output = flash_attn_varlen_func(...)
# New code:
attn_output = self.attention_policy.compute_prefill(
q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
)
```
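The `create_attention_policy` factory referenced above is not spelled out in this plan; a minimal registry-based sketch of what it might look like (names follow the plan, the implementation itself is hypothetical):

```python
# Hypothetical registry-based factory; the real one would live in
# nanovllm/kvcache/sparse/__init__.py and return the policy classes above.
_POLICY_REGISTRY = {}

def register_policy(name):
    """Class decorator registering a policy under a string key."""
    def deco(cls):
        _POLICY_REGISTRY[name] = cls
        return cls
    return deco

def create_attention_policy(name, **kwargs):
    """Instantiate a registered policy, forwarding keyword arguments."""
    if name not in _POLICY_REGISTRY:
        raise ValueError(
            f"Unknown attention policy {name!r}; available: {sorted(_POLICY_REGISTRY)}")
    return _POLICY_REGISTRY[name](**kwargs)

@register_policy("full")
class FullAttentionPolicy:
    def __repr__(self):
        return "FullAttentionPolicy()"

@register_policy("xattn")
class XAttentionPolicy:
    def __init__(self, stride=8, threshold=0.9, block_size=128):
        self.stride, self.threshold, self.block_size = stride, threshold, block_size
```

A registry keeps model_runner.py free of per-policy imports: adding a new policy means registering one class, not editing the factory.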
## Method Mapping
| Old method | New method | Notes |
|--------|--------|------|
| `select_blocks()` | `estimate()` | Computes the sparse mask (maps to xattn_estimate) |
| `sparse_prefill_attention()` | `compute_prefill()` | Prefill attention |
| (none) | `compute_decode()` | Decode attention (default implementation) |
| `on_prefill_offload()` | (removed) | Offload is handled in model_runner |
## Files to Modify
| File | Changes |
|------|---------|
| `nanovllm/kvcache/sparse/policy.py` | New interface: estimate, compute_prefill, compute_decode |
| `nanovllm/kvcache/sparse/full_policy.py` | Implement compute_prefill(); estimate() returns None |
| `nanovllm/kvcache/sparse/xattn.py` | estimate() maps to xattn_estimate; compute_prefill() |
| `nanovllm/kvcache/sparse/__init__.py` | Update the factory function |
| `nanovllm/engine/model_runner.py` | Call attention_policy.compute_prefill() uniformly |
| `nanovllm/config.py` | Optional: rename config keys |
## Decisions Made
1. **Method naming**: `compute_prefill` / `compute_decode`, matching the chunked version's naming style
2. **estimate method**: replaces `select_blocks`; returns a sparse mask instead of block IDs
3. **XAttention**: `estimate()` corresponds to COMPASS's `xattn_estimate()`
4. **Full Policy**: `estimate()` returns None, meaning full attention is used
5. **Default decode**: the base class provides a default FlashAttention implementation
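Decision 5 implies a base class with a working default decode path. A plain-torch sketch of that interface (per the plan the real default uses FlashAttention; this illustrative version computes the same single-token decode attention directly):

```python
import torch
from typing import Optional

class AttentionPolicy:
    """Sketch of the planned policy interface; the real base class would
    live in nanovllm/kvcache/sparse/policy.py."""
    supports_prefill = True
    supports_decode = True

    def estimate(self, q, k, layer_id) -> Optional[torch.Tensor]:
        # None means: no sparse mask, use full attention.
        return None

    def compute_prefill(self, q, k, v, layer_id, softmax_scale):
        raise NotImplementedError

    def compute_decode(self, q, k, v, layer_id, softmax_scale):
        """Default decode: one query token attends to the whole KV cache.
        q: [1, num_heads, head_dim]; k, v: [cache_len, num_heads, head_dim].
        No causal mask needed: every cached position precedes the query."""
        scores = torch.einsum("qhd,khd->hqk", q, k) * softmax_scale
        probs = torch.softmax(scores, dim=-1)
        return torch.einsum("hqk,khd->qhd", probs, v)
```

Subclasses only need to override `estimate()` and `compute_prefill()`; decode comes for free unless a policy wants sparse decode as well.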
## Errors Encountered
- (none)
## Status
**Currently in Phase 1** - analysis and interface design complete; awaiting user confirmation before proceeding to Phase 2

View File

@@ -32,14 +32,11 @@ def run_needle_test(
     enable_cpu_offload: bool = False,
     enable_quest: bool = False,
     enable_minference: bool = False,
-    enable_xattn: bool = False,
     sparse_topk: int = 8,
     sparse_threshold: int = 4,
     minference_budget: float = 0.3,
     minference_vertical: int = 1000,
     minference_slash: int = 6096,
-    xattn_threshold: float = 0.9,
-    xattn_use_bsa: bool = True,
     gpu_utilization: float = 0.9,
     enforce_eager: bool = True,
     verbose: bool = True,
@@ -59,14 +56,11 @@ def run_needle_test(
         enable_cpu_offload: Enable CPU offload mode
         enable_quest: Enable Quest sparse attention (decode-only Top-K)
         enable_minference: Enable MInference sparse prefill (GPU-only)
-        enable_xattn: Enable XAttention sparse prefill with BSA
         sparse_topk: Top-K blocks for Quest
         sparse_threshold: Apply sparse only when blocks > threshold
         minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode)
         minference_vertical: Fixed vertical_size (only used when budget=None)
         minference_slash: Fixed slash_size (only used when budget=None)
-        xattn_threshold: XAttention block selection threshold (0-1)
-        xattn_use_bsa: Use Block Sparse Attention library
         gpu_utilization: GPU memory utilization fraction
         verbose: Print detailed output
@@ -74,9 +68,7 @@ def run_needle_test(
         True if test passed, False otherwise
     """
     # Determine sparse policy
-    if enable_xattn:
-        sparse_policy = SparsePolicyType.XATTN
-    elif enable_minference:
+    if enable_minference:
         sparse_policy = SparsePolicyType.MINFERENCE
     elif enable_quest:
         sparse_policy = SparsePolicyType.QUEST
@@ -102,8 +94,6 @@
             print(f"  MInference: adaptive (budget={minference_budget})")
         else:
             print(f"  MInference: fixed (vertical={minference_vertical}, slash={minference_slash})")
-    if enable_xattn:
-        print(f"  XAttention: threshold={xattn_threshold}, use_bsa={xattn_use_bsa}")
     print(f"{'='*60}\n")
     # 1. Initialize LLM
@@ -121,7 +111,7 @@
     llm_kwargs["sparse_threshold_blocks"] = sparse_threshold
     # Set sparse policy (can be used with or without offload)
-    if enable_minference or enable_quest or enable_xattn:
+    if enable_minference or enable_quest:
         llm_kwargs["sparse_policy"] = sparse_policy
     # MInference params (works with both GPU-only and offload mode)
@@ -130,11 +120,6 @@
         llm_kwargs["minference_vertical_size"] = minference_vertical
         llm_kwargs["minference_slash_size"] = minference_slash
-    # XAttention params
-    if enable_xattn:
-        llm_kwargs["xattn_threshold"] = xattn_threshold
-        llm_kwargs["xattn_use_bsa"] = xattn_use_bsa
     llm = LLM(model_path, **llm_kwargs)
     # 2. Generate needle prompt
@@ -239,11 +224,6 @@ if __name__ == "__main__":
         action="store_true",
         help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)"
     )
-    parser.add_argument(
-        "--enable-xattn",
-        action="store_true",
-        help="Enable XAttention sparse prefill with Block Sparse Attention"
-    )
     parser.add_argument(
         "--sparse-topk",
         type=int,
@@ -274,17 +254,6 @@
         default=6096,
         help="Fixed slash_size (only used when budget=0)"
     )
-    parser.add_argument(
-        "--xattn-threshold",
-        type=float,
-        default=0.9,
-        help="XAttention block selection threshold (0-1, higher=more blocks)"
-    )
-    parser.add_argument(
-        "--xattn-no-bsa",
-        action="store_true",
-        help="Disable Block Sparse Attention (use FlashAttention fallback)"
-    )
     parser.add_argument(
         "--gpu-utilization",
         type=float,
@@ -322,14 +291,11 @@
         enable_cpu_offload=args.enable_offload,
         enable_quest=args.enable_quest,
         enable_minference=args.enable_minference,
-        enable_xattn=args.enable_xattn,
         sparse_topk=args.sparse_topk,
         sparse_threshold=args.sparse_threshold,
         minference_budget=minference_budget,
         minference_vertical=args.minference_vertical,
         minference_slash=args.minference_slash,
-        xattn_threshold=args.xattn_threshold,
-        xattn_use_bsa=not args.xattn_no_bsa,
         gpu_utilization=args.gpu_utilization,
         enforce_eager=enforce_eager,
         verbose=True,

View File

@@ -0,0 +1,841 @@
"""
OffloadedTensor 统一测试套件
本文件整合了 OffloadedTensor 的所有测试,包括:
1. 基础功能验证
2. Chunked GEMM 测试
3. 同步分析
核心组件:
- OffloadedTensor: 虚拟 GPU Tensor支持透明 CPU/GPU 数据移动
- OffloadManager: LRU 缓存管理,支持同步/异步传输
- ChunkedOffloadLinear: 沿着 seqlen 维度分块的 Linear 层
"""
import torch
import torch.nn as nn
import weakref
import threading
import time
from typing import Optional, Dict, List, Tuple, Any
from dataclasses import dataclass
# ============================================================
# Part 1: Core components
# ============================================================
class OffloadedTensor(torch.Tensor):
"""
虚拟 GPU Tensor假装在 GPU 上,实际可能在 CPU
所有计算操作通过 __torch_dispatch__ 拦截,
在计算前自动加载数据到 GPU。
"""
@staticmethod
def __new__(cls, real_tensor: torch.Tensor, manager: 'OffloadManager', tensor_id: int):
device = torch.device("cuda", torch.cuda.current_device())
ret = torch.Tensor._make_wrapper_subclass(
cls,
real_tensor.size(),
strides=real_tensor.stride(),
dtype=real_tensor.dtype,
device=device,
requires_grad=real_tensor.requires_grad
)
ret._real_tensor = real_tensor
ret._manager = weakref.ref(manager)
ret._tensor_id = tensor_id
return ret
def __init__(self, real_tensor: torch.Tensor, manager: 'OffloadManager', tensor_id: int):
self._real_tensor = real_tensor
self._manager = weakref.ref(manager)
self._tensor_id = tensor_id
@property
def device(self) -> torch.device:
"""永远返回 CUDA device欺骗 PyTorch 的检查"""
return torch.device("cuda", torch.cuda.current_device())
def to(self, *args, **kwargs):
"""拦截 .to() 调用"""
device = None
if args and isinstance(args[0], torch.device):
device = args[0]
elif 'device' in kwargs:
device = kwargs['device']
if device and device.type == "cuda":
return self
return super().to(*args, **kwargs)
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
"""拦截所有 PyTorch 操作,自动加载数据"""
kwargs = kwargs or {}
manager = self._manager()
if manager:
manager.stats['dispatch_count'] += 1
        # Special case: detach returns self.
        # func is an OpOverload; str(func) yields e.g. "aten.detach.default"
        # (the original getattr(func, 'name', '') returned a bound method,
        # never a string, so this branch was dead)
        if 'detach' in str(func):
            return self
        # Unwrap OffloadedTensor args into real (GPU-resident) tensors
def unwrap(t):
if isinstance(t, OffloadedTensor):
mgr = t._manager()
if mgr:
return mgr.get_gpu_tensor(t._real_tensor, t._tensor_id)
return t._real_tensor.cuda()
return t
new_args = torch.utils._pytree.tree_map(unwrap, args)
new_kwargs = torch.utils._pytree.tree_map(unwrap, kwargs)
result = func(*new_args, **new_kwargs)
return result
class OffloadManager:
"""
管理 tensor 的卸载和预取
特性:
- LRU 缓存管理 GPU 上的张量
- 支持同步/异步传输模式
- 完整的性能统计
"""
def __init__(
self,
device: str = "cuda",
offload_device: str = "cpu",
max_gpu_tensors: int = 2,
non_blocking: bool = False,
):
self.device = torch.device(device)
self.offload_device = torch.device(offload_device)
self._gpu_pool: Dict[int, torch.Tensor] = {}
self._cpu_storage: Dict[int, torch.Tensor] = {}
self._lock = threading.Lock()
self._tensor_id_counter = 0
self._max_gpu_tensors = max_gpu_tensors
self._access_order: List[int] = []
self.non_blocking = non_blocking
        # Statistics
self.stats = {
'load_count': 0,
'evict_count': 0,
'dispatch_count': 0,
'transfer_times_ms': [],
}
def _next_id(self) -> int:
tid = self._tensor_id_counter
self._tensor_id_counter += 1
return tid
def wrap(self, tensor: torch.Tensor) -> OffloadedTensor:
"""包装 tensor 为虚拟 GPU tensor"""
if isinstance(tensor, OffloadedTensor):
return tensor
tensor_id = self._next_id()
cpu_tensor = tensor.detach().to(self.offload_device)
self._cpu_storage[tensor_id] = cpu_tensor
return OffloadedTensor(cpu_tensor, self, tensor_id)
def get_gpu_tensor(self, real_tensor: torch.Tensor, tensor_id: int) -> torch.Tensor:
"""获取 GPU 上的数据LRU 缓存)"""
with self._lock:
self.stats['load_count'] += 1
if tensor_id in self._gpu_pool:
                # Already on GPU; refresh LRU order
if tensor_id in self._access_order:
self._access_order.remove(tensor_id)
self._access_order.append(tensor_id)
return self._gpu_pool[tensor_id]
            # LRU eviction
while len(self._gpu_pool) >= self._max_gpu_tensors:
if self._access_order:
evict_id = self._access_order.pop(0)
if evict_id in self._gpu_pool:
del self._gpu_pool[evict_id]
self.stats['evict_count'] += 1
else:
break
            # Load onto the GPU, recording the transfer time
            # (with non_blocking=True this measures only the enqueue time;
            # the original never populated transfer_times_ms at all)
            cpu_tensor = self._cpu_storage.get(tensor_id, real_tensor)
            t0 = time.perf_counter()
            gpu_tensor = cpu_tensor.to(self.device, non_blocking=self.non_blocking)
            self.stats['transfer_times_ms'].append((time.perf_counter() - t0) * 1000)
            self._gpu_pool[tensor_id] = gpu_tensor
            self._access_order.append(tensor_id)
            return gpu_tensor
def get_stats(self) -> Dict[str, Any]:
"""获取统计信息"""
transfer_times = self.stats['transfer_times_ms']
return {
'load_count': self.stats['load_count'],
'evict_count': self.stats['evict_count'],
'dispatch_count': self.stats['dispatch_count'],
'gpu_pool_size': len(self._gpu_pool),
'total_tensors': len(self._cpu_storage),
'total_transfer_time_ms': sum(transfer_times),
'avg_transfer_time_ms': sum(transfer_times) / len(transfer_times) if transfer_times else 0,
'transfer_times_ms': list(transfer_times),
}
class OffloadModuleWrapper(nn.Module):
"""包装 nn.Module实现参数级别的卸载"""
def __init__(self, module: nn.Module, manager: OffloadManager):
super().__init__()
self._original_module = module
self._manager = manager
self._wrap_parameters(module, "")
def _wrap_parameters(self, module: nn.Module, prefix: str):
"""递归包装模块的所有参数"""
for name, param in list(module.named_parameters(recurse=False)):
param.requires_grad_(False)
wrapped = self._manager.wrap(param.data)
delattr(module, name)
setattr(module, name, wrapped)
for child_name, child in list(module.named_children()):
self._wrap_parameters(child, prefix + child_name + ".")
def forward(self, *args, **kwargs):
return self._original_module(*args, **kwargs)
# ============================================================
# Part 2: Higher-level modules
# ============================================================
class ChunkedOffloadLinear(nn.Module):
"""
沿着 seqlen 维度分块的 Linear 层
将输入 [seqlen, in_features] 分成多个 chunks每个 chunk 独立进行 GEMM 计算。
weight 使用 OffloadedTensor按需加载到 GPU。
Args:
in_features: 输入特征维度
out_features: 输出特征维度
chunk_size: 每个 chunk 的大小
max_gpu_tensors: GPU 上最多缓存的 tensor 数量
non_blocking: 是否使用异步传输
"""
def __init__(
self,
in_features: int,
out_features: int,
chunk_size: int = 4096,
max_gpu_tensors: int = 2,
non_blocking: bool = False,
bias: bool = False,
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.chunk_size = chunk_size
self.manager = OffloadManager(
max_gpu_tensors=max_gpu_tensors,
non_blocking=non_blocking
)
weight_tensor = torch.empty(out_features, in_features, dtype=torch.float16)
nn.init.xavier_uniform_(weight_tensor)
weight_tensor.requires_grad_(False)
self.weight = self.manager.wrap(weight_tensor)
self.bias = None
if bias:
self.bias = nn.Parameter(torch.empty(out_features))
def forward(self, x: torch.Tensor) -> torch.Tensor:
seqlen = x.shape[0]
if seqlen <= self.chunk_size:
return self._compute_chunk(x)
outputs = []
for start_idx in range(0, seqlen, self.chunk_size):
end_idx = min(start_idx + self.chunk_size, seqlen)
chunk = x[start_idx:end_idx]
chunk_output = self._compute_chunk(chunk)
outputs.append(chunk_output)
return torch.cat(outputs, dim=0)
def _compute_chunk(self, chunk: torch.Tensor) -> torch.Tensor:
return torch.nn.functional.linear(chunk, self.weight, self.bias)
# ============================================================
# Helper functions
# ============================================================
def calculate_memory(
seqlen: int,
in_features: int,
out_features: int,
dtype: torch.dtype = torch.float16,
) -> Dict[str, float]:
"""计算显存占用MB"""
element_size = torch.finfo(dtype).bits / 8
activation = seqlen * in_features * element_size / (1024 ** 2)
weight = in_features * out_features * element_size / (1024 ** 2)
output = seqlen * out_features * element_size / (1024 ** 2)
total = activation + weight + output
peak = max(activation, output) + weight
return {
'activation_mb': activation,
'weight_mb': weight,
'output_mb': output,
'total_mb': total,
'peak_mb': peak,
}
def run_benchmark(
layer: nn.Module,
input_tensor: torch.Tensor,
num_runs: int = 3,
) -> Dict[str, float]:
"""运行性能测试"""
torch.cuda.synchronize()
# Warmup
with torch.no_grad():
_ = layer(input_tensor)
torch.cuda.synchronize()
# Benchmark
start_time = time.time()
for _ in range(num_runs):
with torch.no_grad():
output = layer(input_tensor)
torch.cuda.synchronize()
elapsed = time.time() - start_time
avg_time = elapsed / num_runs
total_elements = input_tensor.numel() + output.numel()
throughput = total_elements / avg_time / 1e6
return {
'avg_time_ms': avg_time * 1000,
'throughput_meps': throughput,
}
# ============================================================
# Part 3: Test suite - functional tests
# ============================================================
def test_1_basic_offloaded_tensor():
"""测试 OffloadedTensor 基本功能"""
print("\n=== Test 1: Basic OffloadedTensor ===")
if not torch.cuda.is_available():
print("CUDA not available, skipping")
return
manager = OffloadManager(max_gpu_tensors=2)
t1 = torch.randn(4, 4)
t2 = torch.randn(4, 4)
t3 = torch.randn(4, 4)
w1 = manager.wrap(t1)
w2 = manager.wrap(t2)
w3 = manager.wrap(t3)
print(f"✓ Created OffloadedTensors")
print(f" w1.device: {w1.device}")
print(f" w2.device: {w2.device}")
assert w1.device.type == "cuda"
print(f"✓ is_cuda check passed")
result = w1 + w2
print(f"✓ Addition works: {result.shape}")
stats = manager.get_stats()
print(f"✓ Manager stats: {stats}")
print("PASSED\n")
def test_2_mlp_with_offload():
"""测试 MLP 模型使用 OffloadedTensor"""
print("\n=== Test 2: MLP with OffloadedTensor ===")
if not torch.cuda.is_available():
print("CUDA not available, skipping")
return
class SimpleMLP(nn.Module):
def __init__(self, hidden_size=128, intermediate_size=256):
super().__init__()
self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
def forward(self, x):
gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
return self.down_proj(nn.functional.silu(gate) * up)
hidden_size = 128
intermediate_size = 256
batch_size, seq_len = 2, 4
input_ids = torch.randn(batch_size, seq_len, hidden_size, device="cuda")
model_original = SimpleMLP(hidden_size, intermediate_size)
model_original.to("cuda")
model_original.eval()
with torch.no_grad():
expected = model_original(input_ids)
state_dict = model_original.state_dict()
model = SimpleMLP(hidden_size, intermediate_size)
model.load_state_dict(state_dict)
model.eval()
offloaded_model, manager = apply_offload_to_model(model, max_gpu_tensors=2)
offloaded_model.eval()
with torch.no_grad():
output = offloaded_model(input_ids)
print(f"✓ Forward pass completed: {output.shape}")
stats = manager.get_stats()
print(f"✓ Offload stats: {stats}")
diff = (output - expected).abs().max().item()
print(f"✓ Output correctness: max diff = {diff:.6f}")
assert diff < 1e-5
print("PASSED\n")
def apply_offload_to_model(model: nn.Module, max_gpu_tensors: int = 2):
"""应用卸载到模型的所有参数"""
manager = OffloadManager(max_gpu_tensors=max_gpu_tensors)
wrapper = OffloadModuleWrapper(model, manager)
return wrapper, manager
def test_3_lru_eviction():
"""测试 LRU 驱逐机制"""
print("\n=== Test 3: LRU Eviction ===")
if not torch.cuda.is_available():
print("CUDA not available, skipping")
return
manager = OffloadManager(max_gpu_tensors=2)
tensors = [torch.randn(2, 2) for _ in range(4)]
wrapped = [manager.wrap(t) for t in tensors]
print(f"✓ Created {len(wrapped)} OffloadedTensors")
print(f" GPU pool capacity: {manager._max_gpu_tensors}")
_ = wrapped[0] + wrapped[1]
stats = manager.get_stats()
print(f"✓ After accessing t1, t2: GPU pool = {stats['gpu_pool_size']}")
_ = wrapped[2] + wrapped[2]
stats = manager.get_stats()
print(f"✓ After accessing t3: GPU pool = {stats['gpu_pool_size']}, evicted = {stats['evict_count']}")
_ = wrapped[3] + wrapped[3]
stats = manager.get_stats()
print(f"✓ After accessing t4: GPU pool = {stats['gpu_pool_size']}, evicted = {stats['evict_count']}")
assert stats['evict_count'] >= 1
print("PASSED\n")
def test_4_correctness():
"""测试输出正确性"""
print("\n=== Test 4: Correctness Check ===")
if not torch.cuda.is_available():
print("CUDA not available, skipping")
return
in_features = 512
out_features = 1024
seqlen = 4096
chunk_size = 1024
x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)
    # Create a reference layer and compute the expected output
linear = nn.Linear(in_features, out_features, bias=False)
linear.to("cuda", dtype=torch.float16)
linear.eval()
with torch.no_grad():
expected = linear(x)
print(f"✓ Got expected output")
    # Create a ChunkedOffloadLinear using the same weights
chunked_layer = ChunkedOffloadLinear(in_features, out_features, chunk_size, max_gpu_tensors=2)
    # Copy the weights into chunked_layer
with torch.no_grad():
weight_data = linear.weight.data.cpu()
chunked_layer.manager._cpu_storage[0] = weight_data
with torch.no_grad():
actual = chunked_layer(x)
print(f"✓ Got actual output")
diff = (actual - expected).abs().max().item()
print(f"✓ Max difference: {diff:.6f}")
assert diff < 1e-5
print("PASSED\n")
# ============================================================
# Part 3: Test suite - performance tests
# ============================================================
def test_5_memory_analysis():
"""分析内存占用"""
print("\n=== Test 5: Memory Analysis ===")
in_features = 4096
out_features = 12244
chunk_size = 4096
seqlens = [4096, 16384, 65536, 131072]
print(f"\nMemory Analysis (in={in_features}, out={out_features}, chunk={chunk_size}):")
print(f"{'Seqlen':>10} | {'Activation':>12} | {'Weight':>12} | {'Output':>12} | {'Peak':>12} | {'Chunked':>12}")
print("-" * 90)
for seqlen in seqlens:
full = calculate_memory(seqlen, in_features, out_features)
chunked = calculate_memory(chunk_size, in_features, out_features)
print(f"{seqlen:>10} | "
f"{full['activation_mb']:>10.1f}MB | "
f"{full['weight_mb']:>10.1f}MB | "
f"{full['output_mb']:>10.1f}MB | "
f"{full['peak_mb']:>10.1f}MB | "
f"{chunked['peak_mb']:>10.1f}MB")
print("\n✓ Chunked offload 显存占用恒定,与序列长度无关!")
print("PASSED\n")
def test_6_long_sequence():
"""测试超长序列"""
print("\n=== Test 6: Long Sequence (128K tokens) ===")
if not torch.cuda.is_available():
print("CUDA not available, skipping")
return
in_features = 4096
out_features = 12244
seqlen = 128 * 1024
chunk_size = 4096
full = calculate_memory(seqlen, in_features, out_features)
chunked = calculate_memory(chunk_size, in_features, out_features)
print(f"Memory Comparison:")
print(f" Full: {full['peak_mb']:.1f} MB")
print(f" Chunked: {chunked['peak_mb']:.1f} MB")
print(f" Savings: {(1 - chunked['peak_mb']/full['peak_mb'])*100:.1f}%")
layer = ChunkedOffloadLinear(in_features, out_features, chunk_size, max_gpu_tensors=1)
x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)
with torch.no_grad():
start = time.time()
output = layer(x)
torch.cuda.synchronize()
elapsed = (time.time() - start) * 1000
print(f"✓ Forward pass: {output.shape}")
print(f" Time: {elapsed:.1f} ms")
print(f" Throughput: {seqlen/elapsed/1e3:.1f}K tokens/sec")
stats = layer.manager.get_stats()
print(f"✓ Chunks processed: {seqlen // chunk_size}")
print(f"✓ Load count: {stats['load_count']}")
print("PASSED\n")
def test_7_performance_comparison():
"""性能对比测试"""
print("\n=== Test 7: Performance Comparison ===")
if not torch.cuda.is_available():
print("CUDA not available, skipping")
return
in_features = 4096
out_features = 12244
seqlen = 16384
chunk_size = 4096
x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)
linear = nn.Linear(in_features, out_features, bias=False).cuda().half().eval()
standard_stats = run_benchmark(linear, x, num_runs=5)
print(f"✓ Standard Linear: {standard_stats['avg_time_ms']:.1f} ms")
chunked_layer = ChunkedOffloadLinear(in_features, out_features, chunk_size, max_gpu_tensors=1)
chunked_stats = run_benchmark(chunked_layer, x, num_runs=5)
print(f"✓ ChunkedOffloadLinear: {chunked_stats['avg_time_ms']:.1f} ms")
speedup = standard_stats['avg_time_ms'] / chunked_stats['avg_time_ms']
print(f"✓ Speedup: {speedup:.2f}x")
print("PASSED\n")
def test_8_transformers_layer():
"""测试实际 transformers 权重"""
print("\n=== Test 8: Transformers Layer Test ===")
try:
from transformers import AutoModelForCausalLM
except ImportError:
print("transformers not installed, skipping")
return
if not torch.cuda.is_available():
print("CUDA not available, skipping")
return
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
try:
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.float16,
trust_remote_code=True,
)
model.eval()
model.to("cuda")
except Exception as e:
print(f"Failed to load model: {e}")
return
down_proj = model.model.layers[0].mlp.down_proj
print(f"✓ Got layer: {down_proj.in_features} -> {down_proj.out_features}")
batch_size, seq_len = 1, 4
test_input = torch.randn(batch_size, seq_len, down_proj.in_features, device="cuda", dtype=torch.float16)
with torch.no_grad():
normal_output = down_proj(test_input)
print(f"✓ Normal inference: {normal_output.shape}")
import copy
test_linear = nn.Linear(down_proj.in_features, down_proj.out_features, bias=False)
test_linear.load_state_dict(copy.deepcopy(down_proj.state_dict()))
test_linear.to("cuda", dtype=torch.float16)
test_linear.eval()
manager = OffloadManager(max_gpu_tensors=2)
offloaded_layer = OffloadModuleWrapper(test_linear, manager)
with torch.no_grad():
offload_output = offloaded_layer(test_input)
print(f"✓ Offload inference: {offload_output.shape}")
stats = manager.get_stats()
print(f"✓ Stats: {stats}")
diff = (offload_output - normal_output).abs().max().item()
print(f"✓ Max diff: {diff:.6f}")
assert diff < 1e-5
print("PASSED\n")
# ============================================================
# Part 3: Test suite - sync analysis
# ============================================================
def test_9_sync_behavior_analysis():
"""分析同步传输 vs 异步传输"""
print("\n=== Test 9: Sync Behavior Analysis ===")
if not torch.cuda.is_available():
print("CUDA not available, skipping")
return
in_features = 4096
out_features = 12244
seqlen = 16384
chunk_size = 4096
print(f"Config: in={in_features}, out={out_features}, seqlen={seqlen}, chunk={chunk_size}")
print(f"Num chunks: {seqlen // chunk_size}")
x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)
    # Synchronous version
    print(f"\n--- Synchronous transfer (non_blocking=False) ---")
layer_sync = ChunkedOffloadLinear(in_features, out_features, chunk_size, non_blocking=False)
with torch.no_grad():
start = time.time()
_ = layer_sync(x)
torch.cuda.synchronize()
sync_time_ms = (time.time() - start) * 1000
stats_sync = layer_sync.manager.get_stats()
print(f"总时间: {sync_time_ms:.2f} ms")
print(f"传输时间: {stats_sync['total_transfer_time_ms']:.2f} ms")
print(f"计算时间: {sync_time_ms - stats_sync['total_transfer_time_ms']:.2f} ms")
print(f"加载次数: {stats_sync['load_count']}")
    # Asynchronous version
    print(f"\n--- Asynchronous transfer (non_blocking=True) ---")
layer_async = ChunkedOffloadLinear(in_features, out_features, chunk_size, non_blocking=True)
with torch.no_grad():
start = time.time()
_ = layer_async(x)
torch.cuda.synchronize()
async_time_ms = (time.time() - start) * 1000
stats_async = layer_async.manager.get_stats()
print(f"总时间: {async_time_ms:.2f} ms")
print(f"传输时间: {stats_async['total_transfer_time_ms']:.2f} ms")
print(f"计算时间: {async_time_ms - stats_async['total_transfer_time_ms']:.2f} ms")
print(f"加载次数: {stats_async['load_count']}")
    # Comparison
    print(f"\n--- Comparison ---")
    print(f"Overall speedup: {sync_time_ms / async_time_ms:.2f}x")
    if stats_async['total_transfer_time_ms'] > 0:
        print(f"Transfer speedup: {stats_sync['total_transfer_time_ms'] / stats_async['total_transfer_time_ms']:.2f}x")
    print("\nKey findings:")
    print(f"  1. Synchronous transfers block the CPU thread")
    print(f"  2. Asynchronous transfers can improve throughput")
    print(f"  3. The first run includes JIT compilation overhead")
print("PASSED\n")
def test_10_profiler_analysis():
"""使用 Profiler 分析内核执行"""
print("\n=== Test 10: Profiler Analysis ===")
if not torch.cuda.is_available():
print("CUDA not available, skipping")
return
in_features = 4096
out_features = 12244
seqlen = 16384
chunk_size = 4096
layer = ChunkedOffloadLinear(in_features, out_features, chunk_size)
x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p:
with torch.no_grad():
_ = layer(x)
torch.cuda.synchronize()
kernel_counts = {}
for event in p.key_averages():
if event.device_type == torch.profiler.DeviceType.CUDA:
name = event.key
kernel_counts[name] = kernel_counts.get(name, 0) + 1
print(f"内核调用统计:")
print(f"{'内核类型':<50} {'调用次数':<10}")
print("-" * 60)
for name, count in sorted(kernel_counts.items(), key=lambda x: -x[1])[:15]:
name_short = name[:48]
print(f"{name_short:<50} {count:<10}")
memcpy_count = sum(count for name, count in kernel_counts.items() if 'memcpy' in name.lower())
print(f"\n分析:")
print(f" - 总共 {len(kernel_counts)} 种不同的 CUDA 内核")
print(f" - 总调用次数: {sum(kernel_counts.values())}")
print(f" - 内存拷贝: {memcpy_count}")
print("PASSED\n")
# ============================================================
# Main entry point
# ============================================================
def main():
"""运行所有测试"""
print("=" * 70)
print("OffloadedTensor 统一测试套件")
print("=" * 70)
    # Functional tests
    print("\n" + "=" * 70)
    print("Functional tests (Tests 1-4)")
    print("=" * 70)
test_1_basic_offloaded_tensor()
test_2_mlp_with_offload()
test_3_lru_eviction()
test_4_correctness()
    # Performance tests
    print("\n" + "=" * 70)
    print("Performance tests (Tests 5-8)")
    print("=" * 70)
test_5_memory_analysis()
test_6_long_sequence()
test_7_performance_comparison()
test_8_transformers_layer()
    # Sync analysis
    print("\n" + "=" * 70)
    print("Sync analysis (Tests 9-10)")
    print("=" * 70)
test_9_sync_behavior_analysis()
test_10_profiler_analysis()
print("=" * 70)
print("所有测试完成!")
print("=" * 70)
if __name__ == "__main__":
main()

View File

@@ -1,244 +0,0 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked
Verify that chunked estimation with EXTERNAL chunking produces the same mask
as standard estimation. This ensures the chunked version can be used in
chunked prefill scenarios without accuracy loss.
Usage:
CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
python tests/test_xattn_estimate_chunked.py
"""
import sys
import traceback
import torch
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked
# ============================================================
# Configuration
# ============================================================
# Configuration for xattn_estimate_chunked consistency test.
# Key requirements for 100% match:
# 1. Use matching chunk_size for both standard and chunked versions
# 2. Use same random seed for reproducibility
# Note: Tiny differences (~0.000001) may occur at boundary cases due to
# floating point precision in cumulative sum calculations.
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096 # External chunking size
# Test sequence lengths
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]
# ============================================================
# Utility Functions
# ============================================================
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
"""Compare two masks and report differences."""
if mask1.shape != mask2.shape:
print(f" Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
return False
diff = (mask1 != mask2).sum().item()
total = mask1.numel()
match_rate = (total - diff) / total * 100
print(f" Match rate: {match_rate:.4f}% ({total - diff}/{total})")
if diff > 0:
diff_indices = torch.where(mask1 != mask2)
print(f" First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")
return diff == 0
def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
"""
Run xattn_estimate_chunked with EXTERNAL chunking.
This simulates how chunked prefill should be used in practice.
"""
batch_size, num_heads, q_len, head_dim = query.shape
_, _, k_len, _ = key.shape
q_block_num = (q_len + block_size - 1) // block_size
k_block_num = (k_len + block_size - 1) // block_size
# If Q fits in one chunk, call directly
if q_len <= chunk_size:
return xattn_estimate_chunked(
query, key,
q_start_pos=0,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# External chunking: split Q and call for each chunk
num_q_chunks = (q_len + chunk_size - 1) // chunk_size
print(f" External chunking: {num_q_chunks} chunks")
combined_attn_sum = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=query.dtype, device=query.device
)
combined_mask = torch.zeros(
batch_size, num_heads, q_block_num, k_block_num,
dtype=torch.bool, device=query.device
)
q_block_offset = 0
for q_chunk_idx in range(num_q_chunks):
q_chunk_start = q_chunk_idx * chunk_size
q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)
q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]
# For causal attention, K accumulates up to current Q position
# q_start_pos=0 means Q starts at position 0 in the full sequence
# K is [0, q_chunk_end) for causal attention
k_end = q_chunk_end
k_chunk = key[:, :, :k_end, :]
attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
q_chunk, k_chunk,
q_start_pos=q_chunk_start,
block_size=block_size,
stride=stride,
threshold=threshold,
use_triton=True,
chunk_size=chunk_size,
)
# Place chunk results into combined output
chunk_q_blocks = mask_chunk.shape[2]
chunk_k_blocks = mask_chunk.shape[3]
combined_attn_sum[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
combined_mask[:, :, q_block_offset:q_block_offset+chunk_q_blocks, :chunk_k_blocks] = mask_chunk
q_block_offset += chunk_q_blocks
return combined_attn_sum, combined_mask
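The causal K-window logic above (each Q chunk attends to all keys from position 0 up to its own end, i.e. `k_end = q_chunk_end`) can be checked in isolation without tensors. A minimal sketch of the window schedule, assuming the same chunk size as the test:

```python
# For each Q chunk [q_start, q_end), causal attention restricts keys to [0, q_end).
# Collect the (q_start, q_end, k_end) triples produced by external chunking.
CHUNK_SIZE = 4096
q_len = 10000  # deliberately not a multiple of CHUNK_SIZE to show the ragged tail

windows = []
for start in range(0, q_len, CHUNK_SIZE):
    end = min(start + CHUNK_SIZE, q_len)
    windows.append((start, end, end))  # k_end == q_chunk_end for causal attention

print(windows)
# -> [(0, 4096, 4096), (4096, 8192, 8192), (8192, 10000, 10000)]
```

Note that the K window grows monotonically with the chunk index, which is why `combined_mask` can be filled with `[:chunk_k_blocks]` slices: earlier chunks simply leave the later K-block columns as zeros, matching the causal structure of the full mask.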
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
"""Test a single sequence length."""
print(f"\nTesting seq_len={seq_len}")
print("=" * 60)
# Generate random Q/K
query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
# Run standard xattn_estimate
print("[1] Running standard xattn_estimate...")
try:
attn_sum_std, mask_std = xattn_estimate(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
use_triton=True,
causal=True,
)
density_std = mask_std.float().mean().item()
print(f" mask shape: {mask_std.shape}, density: {density_std:.4f}")
except Exception as e:
print(f" ERROR: {e}")
traceback.print_exc()
return False
# Run chunked xattn_estimate with EXTERNAL chunking
print("[2] Running chunked xattn_estimate (external chunking)...")
try:
attn_sum_chunked, mask_chunked = run_chunked_externally(
query, key,
block_size=BLOCK_SIZE,
stride=STRIDE,
threshold=THRESHOLD,
chunk_size=CHUNK_SIZE,
)
density_chunked = mask_chunked.float().mean().item()
print(f" mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
except Exception as e:
print(f" ERROR: {e}")
traceback.print_exc()
return False
# Compare results
print("[3] Comparing results...")
chunked_q_blocks = mask_chunked.shape[2]
chunked_k_blocks = mask_chunked.shape[3]
# Extract comparable region from standard mask
mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
# Compare masks
masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")
# Compare attn_sums
attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
print(f" Attn sum max diff: {attn_diff:.6f}")
else:
print(f" Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")
# Clean up GPU memory
del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
torch.cuda.empty_cache()
return masks_match
# ============================================================
# Main Test
# ============================================================
if __name__ == "__main__":
print("XAttention Chunked vs Standard Test")
print("=" * 60)
print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
print(f"External chunk_size={CHUNK_SIZE}")
print()
# Check CUDA availability
if not torch.cuda.is_available():
print("CUDA not available!")
sys.exit(1)
print(f"Using GPU: {torch.cuda.get_device_name(0)}")
print("✓ xattn_estimate imported")
print("✓ xattn_estimate_chunked imported")
# Run tests
all_passed = True
results = []
for seq_len in TEST_SEQ_LENS:
passed = test_single_seq_len(seq_len)
chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
results.append((seq_len, chunks, passed))
if not passed:
all_passed = False
# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
for seq_len, chunks, passed in results:
status = "PASSED" if passed else "FAILED"
print(f" seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")
print("=" * 60)
if all_passed:
print("ALL TESTS PASSED!")
sys.exit(0)
else:
print("SOME TESTS FAILED!")
sys.exit(1)