Compare commits: tzj/layer-... → tzj/vs_off (4 commits)

Commits: b8c00399af, 13586e689b, e72725c12b, cfb188c34a
@@ -1,158 +0,0 @@
---
allowed-tools: Bash(CUDA_VISIBLE_DEVICES=*), Bash(PYTHONPATH=*), Bash(python*), Bash(git*), Bash(rm*), Bash(ls*), Bash(cat*), Bash(nvidia-smi*), Read, Edit, Write, Glob, Grep, TodoWrite, Task
argument-hint: --gpu <id> [--no-interrupt]
description: Execute task_plan.md refactoring with specified GPU, optionally without user interruption
---

# Execute Task Plan (exec-plan)

Execute the code refactoring described in `task_plan.md`, making sure the plan's final goals are fully achieved.

## Arguments

Command format: `/exec-plan --gpu <id> [--no-interrupt]`

| Argument | Description | Example |
|------|------|------|
| `--gpu <id>` | **Required**. ID of the only GPU that may be used for debugging | `--gpu 0`, `--gpu 2` |
| `--no-interrupt` | Optional. Never interrupt execution; resolve or skip problems automatically instead of asking the user | `--no-interrupt` |

## Current Arguments

```
$ARGUMENTS
```

## Pre-Execution Preparation

### 1. Parse Arguments

Parse from `$ARGUMENTS`:
- `GPU_ID`: extracted from `--gpu <id>` or `-g <id>`
- `NO_INTERRUPT`: whether the `--no-interrupt` or `-n` flag is present

### 2. Validate Arguments

**Must verify**:
- GPU_ID is a valid number
- Run `nvidia-smi -i <GPU_ID>` to confirm the GPU exists
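The parsing and validation above can be sketched in Python (a minimal illustration; the function name and regex handling are illustrative assumptions, not part of the command itself):

```python
import re

def parse_exec_plan_args(arguments: str) -> tuple[int, bool]:
    """Parse /exec-plan arguments: --gpu/-g <id> (required), --no-interrupt/-n (optional)."""
    match = re.search(r"(?:--gpu|-g)\s+(\S+)", arguments)
    if match is None:
        raise ValueError("--gpu <id> is required")
    if not match.group(1).isdigit():
        raise ValueError(f"GPU_ID must be a valid number, got: {match.group(1)}")
    gpu_id = int(match.group(1))
    no_interrupt = re.search(r"(?:--no-interrupt|-n)(?:\s|$)", arguments) is not None
    return gpu_id, no_interrupt
```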

### 3. Read task_plan.md

Read `task_plan.md` in the project root and understand:
- The overall goal
- The phased plan (Phase 1, 2, 3...)
- The list of files to modify
- Risks and caveats
- The test plan

## Execution Flow

### Step 1: Create an Execution Plan

Use the TodoWrite tool to create a detailed execution plan, including:
- All Phases extracted from task_plan.md
- The subtasks of each Phase
- Test and verification steps

### Step 2: Refactor Phase by Phase

For each Phase in task_plan.md:

1. **Read the current code**: use Read/Grep to understand the existing implementation
2. **Apply the changes**: use Edit/Write to modify the code
3. **Verify the changes**: run the relevant tests

### Step 3: Run the Test Verification

Execute the test plan defined in task_plan.md to verify the refactoring succeeded.

## GPU Restriction Rules

**Strict restriction**: only the specified GPU may be used; every GPU-related command must carry a `CUDA_VISIBLE_DEVICES` prefix:

```bash
# Correct
CUDA_VISIBLE_DEVICES=$GPU_ID PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python test.py

# Wrong - other GPUs are forbidden
python test.py                            # may fall back to default GPU 0
CUDA_VISIBLE_DEVICES=0,1 python test.py   # uses multiple GPUs
```

## Interrupt Mode Rules

### When `--no-interrupt` is active

In the following situations, **do not stop to ask the user**; instead:

| Situation | Handling |
|------|----------|
| Test failure | Record the cause, attempt an automatic fix, continue to the next step |
| Code conflict | Attempt a reasonable resolution and record it |
| Uncertain implementation detail | Choose the most reasonable option and continue |
| Execution error | Analyze the error, attempt a fix, record the problem |

**Automatic decision principles**:
1. Prioritize functional correctness
2. Follow the existing code style
3. Prefer simple, direct implementations
4. Record every automatic decision in `progress.md`

### When `--no-interrupt` is not specified

You **may ask the user** when:
- A choice must be made among several implementation options
- Tests keep failing and cannot be fixed automatically
- A problem or contradiction is found in task_plan.md

## Execution Records

### Progress file: progress.md

Update `progress.md` in real time with:

```markdown
## Execution Progress

### Phase X: [name]
- Status: [in progress/done/failed]
- Started: [time]
- Finished: [time]
- Modified files: [file list]
- Automatic decisions: [if any]
- Issues: [if any]
```

### Findings file: findings.md

Record important findings made during execution in `findings.md`.

## Example Usage

```bash
# Use GPU 2, interruptions allowed
/exec-plan --gpu 2

# Use GPU 0, no interruptions
/exec-plan --gpu 0 --no-interrupt

# Short form
/exec-plan -g 1 -n
```

## Completion Criteria

After execution, make sure that:

1. **All Phases are done**: every Phase in task_plan.md has been implemented
2. **Tests pass**: the entire test plan in task_plan.md passes
3. **Code quality**: the changes follow the project's coding conventions
4. **Docs updated**: progress.md contains a complete execution record

## Important Constraints

1. **GPU isolation**: never use any device other than the specified GPU
2. **Follow the plan**: execute task_plan.md strictly; make no out-of-plan changes
3. **Incremental changes**: verify after each Phase rather than all at once at the end
4. **Rollback readiness**: before a major change, consider whether a git commit checkpoint is needed
||||||
6
.gitmodules
vendored
6
.gitmodules
vendored
@@ -1,4 +1,4 @@
-[submodule "3rdparty/Block-Sparse-Attention"]
-	path = 3rdparty/Block-Sparse-Attention
-	url = git@github.com:Zijie-Tian/Block-Sparse-Attention.git
+[submodule "3rdparty/Block-SparseAttention"]
+	path = 3rdparty/Block-SparseAttention
+	url = https://github.com/Zijie-Tian/Block-SparseAttention.git
 	branch = tzj/minference
@@ -64,6 +64,7 @@ PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH python tests/test_needle.py
 | [`docs/xattention_integration.md`](docs/xattention_integration.md) | XAttention integration guide: algorithm, implementation, design decisions, and testing |
 | [`docs/xattention_analysis.md`](docs/xattention_analysis.md) | XAttention algorithm analysis: chunked estimation, block sparse attention, integration design |
 | [`docs/development_notes.md`](docs/development_notes.md) | Development notes and scratchpad for ongoing work |
+| [`docs/chunked_prefill_analysis.md`](docs/chunked_prefill_analysis.md) | **NEW**: Chunked prefill for ultra-long sequences (1M+), memory analysis, MLP activation breakdown, implementation guide |

 ## Configuration
|
|||||||
1055
docs/chunked_prefill_analysis.md
Normal file
1055
docs/chunked_prefill_analysis.md
Normal file
File diff suppressed because it is too large
Load Diff
354
docs/chunked_prefill_integration_plan.md
Normal file
354
docs/chunked_prefill_integration_plan.md
Normal file
@@ -0,0 +1,354 @@
# Chunked Prefill Integration Plan

**Goal**: port the chunked prefill mechanism from the tzj/minference branch to the tzj/vs_offload branch

**Created**: 2026-01-18
**Base branch**: `tzj/vs_offload`
**Source branch**: `tzj/minference`

---

## Goal

Implement chunked prefill + layerwise offload on the tzj/vs_offload branch so that inference of arbitrary length (4M, 8M, 16M+ tokens) runs on a 24GB RTX 3090.

---

## Core Problem

### Limitations of the tzj/vs_offload branch

The current tzj/vs_offload branch allocates its GPU ring buffer by `max_seq_len`, so GPU memory grows linearly with sequence length:

```python
# Current design
self.layer_k_cache = torch.zeros(
    num_kv_buffers,   # e.g., 4
    max_seq_len,      # e.g., 131072 tokens
    kv_heads,
    head_dim,
    dtype=dtype, device="cuda"
)
```

**Problem**:
- GPU memory requirement ~ `max_seq_len × 4 × 8 × 128 × 2 bytes`
- Infeasible for ultra-long sequences:
  - 4M tokens → ~64 GB GPU memory ❌
  - 8M tokens → ~128 GB GPU memory ❌
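The linear growth can be checked against the plan's own formula (K and V caches together double the per-token cost; numbers assume fp16, 4 buffers, 8 KV heads, head_dim 128, as above):

```python
def ring_buffer_bytes(max_seq_len: int, num_kv_buffers: int = 4,
                      kv_heads: int = 8, head_dim: int = 128,
                      dtype_bytes: int = 2) -> int:
    # K and V caches together -> leading factor of 2
    return 2 * num_kv_buffers * max_seq_len * kv_heads * head_dim * dtype_bytes

print(ring_buffer_bytes(128 * 1024) / 2**30)       # 128K tokens -> 2.0 GiB (~2.15 GB)
print(ring_buffer_bytes(4 * 1024 * 1024) / 2**30)  # 4M tokens   -> 64.0 GiB
```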
### Solution: Block-Based Design

The tzj/minference branch uses a block-based design with a fixed GPU footprint:

```python
# Block-based design
self.k_cache_gpu = torch.zeros(
    num_gpu_blocks,   # e.g., 2
    block_size,       # e.g., 1024 tokens (fixed!)
    kv_heads,
    head_dim,
    dtype=dtype, device="cuda"
)
# GPU memory: ~4 MB (fixed, does not grow with sequence length)
```

**Advantages**:
- Fixed GPU memory (~1.6 GB total), independent of sequence length
- A 24GB RTX 3090 can run 4M+ tokens
- Ultra-long sequences are processed block by block via chunked prefill
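Plugging the same head configuration into the block-based shape confirms the fixed footprint (the ~4 MB figure in the code comment corresponds to the K tensor alone; the V tensor adds the same amount again):

```python
def block_ring_buffer_bytes(num_gpu_blocks: int = 2, block_size: int = 1024,
                            kv_heads: int = 8, head_dim: int = 128,
                            dtype_bytes: int = 2) -> int:
    # One tensor (k_cache_gpu); note there is no sequence-length term at all
    return num_gpu_blocks * block_size * kv_heads * head_dim * dtype_bytes

print(block_ring_buffer_bytes() / 2**20)  # 4.0 MiB, for any sequence length
```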
---

## Memory Layout Comparison

| Component | tzj/vs_offload | tzj/minference | Notes |
|------|---------------|----------------|------|
| **GPU Ring Buffer** | `[num_kv_buffers, max_seq_len, ...]` | `[num_gpu_blocks, block_size, ...]` | minference has no layer dimension |
| **GPU memory** | ~2.15 GB (128K) → ~64 GB (4M) | ~4 MB (fixed) | minference saves substantially |
| **Prefill Buffer** | ❌ none | ✅ `[num_layers, block_size, ...]` | minference only |
| **Pipeline Buffers** | ❌ none | ✅ double buffers `[blocks, block_size, ...]` | minference only |
| **CPU Cache** | `[num_layers, num_cpu_blocks, block_size, ...]` | same | **identical** |
### Sequence Length Support Comparison

| Sequence length | vs_offload GPU memory | minference GPU memory | RTX 3090 (24GB) |
|----------|-------------------|---------------------|-----------------|
| 128K tokens | ~2.15 GB | ~4 MB | ✅ both work |
| 1M tokens | ~16 GB | ~4 MB | ✅ both work |
| **4M tokens** | **~64 GB** ❌ | **~4 MB** ✅ | **minference only** |
| **8M tokens** | **~128 GB** ❌ | **~4 MB** ✅ | **minference only** |
| **16M+ tokens** | **~256 GB+** ❌ | **~4 MB** ✅ | **minference only** |

---

## Key Design Principles

1. **Block-based design**: organized by `block_size` (1024 tokens), enabling chunked prefill
2. **Fixed GPU memory**: a constant factor that does not grow with sequence length
3. **CPU memory scales linearly**: `num_cpu_blocks = ceil(seq_len / block_size)`
4. **Unified ring buffer**: no layer dimension; all layers share slots
5. **Fully parallel offload**: per-layer buffers maximize PCIe bandwidth
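Principle 3 in concrete form, with the block_size used throughout this plan:

```python
import math

def num_cpu_blocks(seq_len: int, block_size: int = 1024) -> int:
    # CPU-side block count scales linearly with sequence length
    return math.ceil(seq_len / block_size)

print(num_cpu_blocks(4 * 1024 * 1024))  # 4096 blocks for 4M tokens
print(num_cpu_blocks(1500))             # 2 (a partial last block still needs a slot)
```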

---

## Unified Memory Layout Design

### GPU Memory Layout

```python
class OffloadEngine:
    # 1. Unified ring buffer - block-based, no layer dimension
    self.k_cache_gpu = torch.zeros(
        num_gpu_blocks,   # e.g., 2
        block_size,       # e.g., 1024
        kv_heads,
        head_dim,
        dtype=dtype, device="cuda"
    )  # ~4 MB (fixed)

    # 2. Per-layer prefill buffer - fully parallel offload
    self.prefill_k_buffer = torch.zeros(
        num_layers, block_size, kv_heads, head_dim,
        dtype=dtype, device="cuda"
    )  # ~58 MB (fixed)

    # 3. Cross-layer pipeline buffers - double-buffering
    self.layer_k_buffer_a = torch.zeros(
        max_prefill_blocks, block_size, kv_heads, head_dim,
        dtype=dtype, device="cuda"
    )  # ~512 MB (fixed)
    self.layer_k_buffer_b = torch.zeros(...)  # ~512 MB (fixed)

    # 4. Per-layer decode buffer
    self.decode_k_buffer = torch.zeros(
        num_layers, block_size, kv_heads, head_dim,
        dtype=dtype, device="cuda"
    )  # ~58 MB (fixed)

    # GPU total: ~1.6 GB (fixed, does not grow with sequence length)
```

### CPU Memory Layout

```python
# CPU cache - has a block dimension
self.k_cache_cpu = torch.zeros(
    num_layers,
    num_cpu_blocks,   # scales with sequence length
    block_size,
    kv_heads,
    head_dim,
    dtype=dtype, device="cpu", pin_memory=True
)
# 128K tokens: ~2.9 GB
# 1M tokens:   ~5.8 GB
# 4M tokens:   ~23.3 GB
```

---

## Chunked Prefill Flow

### Prefill Phase

```
For each chunk:
├── 1. Prepare chunk input (block_size tokens)
├── 2. Get ring buffer slot: slot = chunk_idx % num_gpu_blocks
├── 3. Load previous KV chunks to ring slots[1..N-1]
├── 4. Model Forward (all layers)
│     For each layer:
│     ├── Load previous KV from ring slots
│     ├── Compute attention (current chunk + previous)
│     ├── Write KV to prefill_buffer[layer_id] ← Per-layer!
│     └── Async offload to CPU (parallel across layers)
├── 5. Merge attention outputs (LSE)
└── 6. Record compute done for slot

Key: Per-layer prefill buffer → Layer 0 offload || Layer 1 compute || Layer 2 load ...
```
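Step 5's LSE merge follows the standard online-softmax identity: partial outputs are reweighted by `exp(lse_i - lse)` with `lse = logaddexp(lse1, lse2)`. A NumPy sketch of the math (illustrative only; the engine's actual fused kernel is not shown in this plan):

```python
import numpy as np

def partial_attention(q, k, v):
    """Naive attention over one KV chunk; returns output and per-query LSE."""
    s = q @ k.T                                # [num_q, num_k] scores
    lse = np.log(np.exp(s).sum(axis=-1))       # log-sum-exp of each score row
    out = np.exp(s - lse[:, None]) @ v         # softmax over this chunk only
    return out, lse

def merge_attention(o1, lse1, o2, lse2):
    """Merge two partial outputs into the exact full-attention result."""
    lse = np.logaddexp(lse1, lse2)             # combined normalizer
    w1 = np.exp(lse1 - lse)[:, None]
    w2 = np.exp(lse2 - lse)[:, None]
    return w1 * o1 + w2 * o2, lse
```

Merging the per-chunk results is exact: attention over the concatenated KV equals the LSE-weighted sum of attention over each chunk.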

### Decode Phase

```
├── 1. Setup pipeline: preload Layer 0 to buffer_a
├── 2. For each layer:
│     ├── Get KV from pipeline buffer (a or b)
│     ├── Trigger preload of next layer to other buffer
│     ├── Compute attention
│     └── Store to decode buffer
└── 3. End pipeline

Key: Double-buffering → Layer N compute || Layer N+1 load
```
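The double-buffered loop above can be sketched in plain Python (a control-flow illustration only; `preload` and `compute` stand in for the real async copy and attention kernels):

```python
def decode_all_layers(num_layers, preload, compute):
    """Double-buffered decode: preload layer n+1 into the idle buffer while
    layer n computes on the other one."""
    bufs = ("buffer_a", "buffer_b")
    preload(0, bufs[0])                                # 1. setup: Layer 0 -> buffer_a
    for layer in range(num_layers):                    # 2. per-layer loop
        if layer + 1 < num_layers:
            preload(layer + 1, bufs[(layer + 1) % 2])  # prefetch into the other buffer
        compute(layer, bufs[layer % 2])                # attention on the current buffer
```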

---

## Merge Strategy

### Base branch choice: tzj/vs_offload

**Reasons**:
1. More complete documentation system
2. More complete sparse attention implementations (QUEST, XAttention, etc.)
3. Clearer code organization and comments
4. More active development and maintenance

### Porting Strategy

**Port from tzj/minference**:
1. GPU cache memory layout (no layer dimension, block-based)
2. Per-layer prefill buffer
3. Cross-layer pipeline buffers
4. Chunked prefill flow
5. Online LSE merge mechanism

**Keep the strengths of tzj/vs_offload**:
1. Documentation system
2. Sparse policy architecture
3. Code organization and comments

---

## Sparse Policy Strategy

**Strategy**: keep the architecture; implement only FULL for now

- **Keep** the sparse policy architecture and interfaces
- **Reserve** extension points for future policies such as QUEST
- **Implement only** the FULL policy for now, to ensure correctness and stability
### Implementation

```python
class SparsePolicy(ABC):
    @property
    def supports_prefill(self) -> bool:
        return False

    @property
    def supports_decode(self) -> bool:
        return True

    def on_prefill_offload(self, cpu_block_id, layer_id, k_cache, num_valid_tokens):
        """Reserved for future policies (e.g., QUEST metadata collection)."""
        pass

    def select_blocks(self, available_blocks, context) -> List[int]:
        """FULL: return all available blocks."""
        return available_blocks


class FullAttentionPolicy(SparsePolicy):
    @property
    def supports_prefill(self) -> bool:
        return True

    @property
    def supports_decode(self) -> bool:
        return True
```

---

## Key APIs

### Ring Buffer Management

```python
# Prefill phase
get_write_slot_for_prefill(chunk_idx) -> slot_idx
get_load_slots_for_prefill(write_slot_idx) -> [slot_ids]

# Decode phase
get_load_slots_for_decode() -> [slot_ids]  (excludes decode_slot)
```
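A minimal sketch of the slot bookkeeping these signatures imply, assuming the round-robin rule from the prefill flow (`slot = chunk_idx % num_gpu_blocks`) and that load slots are simply all slots other than the one being written; the real engine's semantics may differ:

```python
def get_write_slot_for_prefill(chunk_idx: int, num_gpu_blocks: int) -> int:
    # Ring buffer: chunks reuse GPU slots round-robin
    return chunk_idx % num_gpu_blocks

def get_load_slots_for_prefill(write_slot_idx: int, num_gpu_blocks: int) -> list[int]:
    # Previous KV lives in every slot except the one being written this chunk
    return [s for s in range(num_gpu_blocks) if s != write_slot_idx]
```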

### Per-layer Operations

```python
# Loading
load_to_slot_layer(slot_idx, layer_id, cpu_block_id)
wait_slot_layer(slot_idx)

# Prefill buffer
get_prefill_buffer(layer_id) -> (k, v)
offload_prefill_buffer_async(layer_id, cpu_block_id, num_tokens)
wait_prefill_offload(layer_id)

# Pipeline
start_decode_pipeline(cpu_block_ids)
get_decode_layer_kv(layer_id, num_blocks) -> (k, v)
end_decode_pipeline()
```
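The async-offload/wait pair can be illustrated with a thread pool standing in for CUDA copy streams (a hypothetical stand-in; the real engine issues asynchronous device-to-host copies, not Python threads):

```python
from concurrent.futures import ThreadPoolExecutor

class PrefillOffloader:
    """Sketch of the per-layer async offload API."""
    def __init__(self, num_layers: int):
        self._pool = ThreadPoolExecutor(max_workers=num_layers)
        self._pending = {}           # layer_id -> in-flight copy
        self.cpu_cache = {}          # (layer_id, cpu_block_id) -> data

    def offload_prefill_buffer_async(self, layer_id, cpu_block_id, data):
        # Start the copy without blocking the compute path
        self._pending[layer_id] = self._pool.submit(
            self.cpu_cache.__setitem__, (layer_id, cpu_block_id), data)

    def wait_prefill_offload(self, layer_id):
        # Block until this layer's copy has landed in CPU memory
        self._pending.pop(layer_id).result()
```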

---

## Implementation Phases

### Phase 1: Memory layout refactor
- Change the GPU cache to a unified ring buffer
- Add the per-layer prefill buffer
- Add the cross-layer pipeline buffers

### Phase 2: API implementation
- Implement the ring buffer slot management API
- Implement the per-layer prefill offload API
- Implement the cross-layer pipeline API

### Phase 3: Integration into the Attention Layer
- Modify the attention forward flow
- Integrate the per-layer prefill buffer
- Integrate the cross-layer pipeline

### Phase 4: Integration into the Model Runner
- Implement the chunked prefill flow
- Integrate LSE merging
- Optimize the pipeline

### Phase 5: Sparse Policy integration (FULL)
- Design a unified policy interface
- Implement FullAttentionPolicy
- Reserve extension points for future policies such as QUEST

---

## Key Decisions

1. **Block-based design first**: the core enabler for arbitrary-length inference
2. **Adopt tzj/minference's memory layout**: GPU cache without a layer dimension, block-based
3. **Use tzj/vs_offload as the base branch**: better documentation and code organization
4. **Staged merge strategy**: reduces complexity and eases verification
5. **Sparse Policy - FULL first**: keep the architecture; implement only FULL for now

---

## Expected Results

### Memory usage (28-layer model, block_size=1024)

| Component | Memory |
|------|------|
| GPU Unified Ring Buffer | ~4 MB |
| GPU Per-layer Prefill Buffer | ~58 MB |
| GPU Pipeline Buffers (×2) | ~1 GB |
| GPU Decode Buffer | ~58 MB |
| **GPU total** | **~1.6 GB (fixed)** |
| CPU Cache (4M tokens) | ~23.3 GB |
| **Total (4M tokens)** | **~24.9 GB** ✅ fits a 24GB RTX 3090 (the GPU share is only ~1.6 GB) |

### Performance Support

- ✅ Supports inference at 4M, 8M, 16M+ tokens
- ✅ Fixed GPU memory, independent of sequence length
- ✅ Fully parallel layerwise offload
- ✅ Cross-layer pipeline optimization

---

## References

- **OffloadEngine**: `nanovllm/kvcache/offload_engine.py`
- **Attention Layer**: `nanovllm/layers/attention.py`
- **Model Runner**: `nanovllm/engine/model_runner.py`
- **Sparse Policy**: `nanovllm/kvcache/sparse/policy.py`
@@ -62,7 +62,6 @@ class Config:
     xattn_keep_sink: bool = False  # Always keep first block (sink tokens)
     xattn_keep_recent: bool = False  # Always keep recent diagonal blocks
     xattn_norm: float = 1.0  # Normalization factor for attention scores
-    xattn_use_bsa: bool = True  # Use Block Sparse Attention library (requires installation)

     def __post_init__(self):
         assert os.path.isdir(self.model)
@@ -57,8 +57,8 @@ class ModelRunner:
         load_model(self.model, config.model)
         self.sampler = GreedySampler()

-        # Initialize attention_policy before warmup (will be configured in allocate_kv_cache)
-        self.attention_policy = None
+        # Initialize sparse_prefill_policy before warmup (will be configured in allocate_kv_cache)
+        self.sparse_prefill_policy = None

         #> Disable warmup for debugging
         self.warmup_model()
@@ -178,35 +178,38 @@ class ModelRunner:
         # Create KV cache manager using factory
         self.kvcache_manager: KVCacheManager = create_kvcache_manager(config)

-        # Create attention policy (always, including FULL)
-        # In layerwise offload mode, all attention goes through the policy
-        from nanovllm.kvcache.sparse import create_attention_policy
+        # Create sparse prefill policy
+        # This is used for both GPU-only and CPU offload modes when policy supports prefill
+        self.sparse_prefill_policy = None
+        if config.sparse_policy != SparsePolicyType.FULL:
+            from nanovllm.kvcache.sparse import create_sparse_policy

         # Get policy-specific parameters based on type
         if config.sparse_policy == SparsePolicyType.XATTN:
             policy_kwargs = {
                 "stride": config.xattn_stride,
                 "threshold": config.xattn_threshold,
                 "chunk_size": config.xattn_chunk_size,
                 "use_triton": config.xattn_use_triton,
                 "keep_sink": config.xattn_keep_sink,
                 "keep_recent": config.xattn_keep_recent,
                 "norm": config.xattn_norm,
-                "use_bsa": config.xattn_use_bsa,
             }
-        else:  # MINFERENCE or others
+        elif config.sparse_policy == SparsePolicyType.MINFERENCE:
             policy_kwargs = {
                 "vertical_size": config.minference_vertical_size,
                 "slash_size": config.minference_slash_size,
                 "adaptive_budget": config.minference_adaptive_budget,
                 "num_sink_tokens": config.minference_num_sink_tokens,
                 "num_recent_diags": config.minference_num_recent_diags,
             }
+        else:  # FULL or QUEST
+            policy_kwargs = {}

-        self.attention_policy = create_attention_policy(config.sparse_policy, **policy_kwargs)
-        logger.info(f"Attention policy: {self.attention_policy}")
+        policy = create_sparse_policy(config.sparse_policy, **policy_kwargs)
+        # Only use if policy supports sparse prefill
+        if policy.supports_prefill:
+            self.sparse_prefill_policy = policy
+            logger.info(f"Sparse prefill policy enabled: {self.sparse_prefill_policy}")

         # Allocate cache through manager
         self.kvcache_manager.allocate_cache(
@@ -392,7 +395,7 @@ class ModelRunner:
         set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                     slot_mapping, None, block_tables,
-                    attention_policy=self.attention_policy)
+                    sparse_prefill_policy=self.sparse_prefill_policy)
         return input_ids, positions

     def prepare_decode(self, seqs: list[Sequence]):
@@ -589,11 +592,21 @@ class ModelRunner:
         # RoPE
         q, k = layer.self_attn.rotary_emb(positions, q, k)

-        # Compute attention using policy (uses k, v directly - before store!)
-        attn_output = self.attention_policy.compute_prefill(
-            q, k, v, layer_id,
-            softmax_scale=layer.self_attn.attn.scale,
-        )
+        # Sparse or Full attention (uses k, v directly - before store!)
+        if self.sparse_prefill_policy is not None:
+            attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
+                q, k, v, layer_id
+            )
+        else:
+            attn_output = flash_attn_varlen_func(
+                q, k, v,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=total_tokens,
+                max_seqlen_k=total_tokens,
+                softmax_scale=layer.self_attn.attn.scale,
+                causal=True,
+            )

         # O projection
         attn_output = attn_output.view(total_tokens, -1)
@@ -859,11 +872,23 @@ class ModelRunner:
         # RoPE
         q, k = layer.self_attn.rotary_emb(positions, q, k)

-        # Compute attention using policy
-        attn_output = self.attention_policy.compute_prefill(
-            q, k, v, layer_id,
-            softmax_scale=layer.self_attn.attn.scale,
-        )
+        # Sparse or Full attention
+        if self.sparse_prefill_policy is not None:
+            # MInference or other sparse prefill policy
+            attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
+                q, k, v, layer_id
+            )
+        else:
+            # Full attention using FlashAttention
+            attn_output = flash_attn_varlen_func(
+                q, k, v,
+                cu_seqlens_q=cu_seqlens,
+                cu_seqlens_k=cu_seqlens,
+                max_seqlen_q=total_tokens,
+                max_seqlen_k=total_tokens,
+                softmax_scale=layer.self_attn.attn.scale,
+                causal=True,
+            )

         # O projection
         attn_output = attn_output.view(total_tokens, -1)
@@ -1,56 +1,49 @@
 """
-Attention Policy module for layerwise offload mode.
+Sparse Attention Policy module.

-Provides pluggable policies for attention computation:
-- FullAttentionPolicy: Standard FlashAttention (no sparsity)
-- XAttentionPolicy: Sparse prefill using XAttention algorithm
-- MInferencePolicy: MInference sparse attention
-- QuestPolicy: Quest block selection (for chunked offload)
+Provides pluggable policies for selecting which KV blocks to load
+during chunked attention with CPU offload.

 Usage:
-    from nanovllm.kvcache.sparse import create_attention_policy, SparsePolicyType
+    from nanovllm.kvcache.sparse import create_sparse_policy, SparsePolicyType

     # Create policy using factory function
-    policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
-
-    # Use policy for attention
-    attn_output = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
+    policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=8)

     # Or create custom policy
-    class MyPolicy(AttentionPolicy):
+    class MyPolicy(SparsePolicy):
         supports_prefill = True
         supports_decode = True

-        def compute_prefill(self, q, k, v, layer_id, softmax_scale):
-            # Custom attention computation
-            ...
+        def select_blocks(self, available_blocks, ctx):
+            return available_blocks[:5]  # Just first 5 blocks
 """

 from nanovllm.config import SparsePolicyType
-from nanovllm.kvcache.sparse.policy import AttentionPolicy, SparsePolicy, PolicyContext
+from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
 from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
 from nanovllm.kvcache.sparse.quest import QuestPolicy, QuestConfig, BlockMetadataManager
 from nanovllm.kvcache.sparse.minference import MInferencePolicy
 from nanovllm.kvcache.sparse.xattn import XAttentionPolicy


-def create_attention_policy(policy_type: SparsePolicyType, **kwargs) -> AttentionPolicy:
+def create_sparse_policy(policy_type: SparsePolicyType, **kwargs) -> SparsePolicy:
     """
-    Create an attention policy instance from an enum type.
+    Create a sparse policy instance from an enum type.

-    All attention (including full attention) goes through a policy in layerwise
-    offload mode. The policy is responsible for computing prefill/decode attention.
+    The returned policy is not yet initialized. Call policy.initialize()
+    or let the framework call it during KV cache allocation.

     Args:
-        policy_type: SparsePolicyType enum value (FULL, XATTN, MINFERENCE, QUEST)
+        policy_type: SparsePolicyType enum value
         **kwargs: Policy-specific configuration options

     Returns:
-        AttentionPolicy instance
+        SparsePolicy instance (not initialized)

     Example:
-        policy = create_attention_policy(SparsePolicyType.XATTN, threshold=0.9)
-        attn_out = policy.compute_prefill(q, k, v, layer_id, softmax_scale)
+        policy = create_sparse_policy(SparsePolicyType.QUEST, topk_blocks=4)
+        policy.initialize(num_layers=28, num_kv_heads=8, ...)
     """
     if policy_type == SparsePolicyType.FULL:
         return FullAttentionPolicy()
@@ -82,32 +75,21 @@ def create_attention_policy(policy_type: SparsePolicyType, **kwargs) -> Attentio
             keep_sink=kwargs.get("keep_sink", False),
             keep_recent=kwargs.get("keep_recent", False),
             norm=kwargs.get("norm", 1.0),
-            use_bsa=kwargs.get("use_bsa", True),
         )

     else:
         raise ValueError(f"Unknown policy type: {policy_type}")


-# Backward compatibility alias
-create_sparse_policy = create_attention_policy


 __all__ = [
-    # New interface
-    "AttentionPolicy",
-    "create_attention_policy",
-    # Backward compatibility
     "SparsePolicy",
-    "create_sparse_policy",
-    # Common types
     "PolicyContext",
     "SparsePolicyType",
-    # Policy implementations
     "FullAttentionPolicy",
     "QuestPolicy",
     "QuestConfig",
     "BlockMetadataManager",
     "MInferencePolicy",
     "XAttentionPolicy",
+    "create_sparse_policy",
 ]
@@ -1,21 +1,20 @@
 """
-Full attention policy - standard FlashAttention without sparsity.
+Full attention policy - loads all blocks (no sparsity).

 This serves as a baseline and default policy when sparse
 attention is not needed.
 """

-from typing import Optional
-import torch
-from .policy import AttentionPolicy
+from typing import List
+from .policy import SparsePolicy, PolicyContext


-class FullAttentionPolicy(AttentionPolicy):
+class FullAttentionPolicy(SparsePolicy):
     """
-    Full attention policy using FlashAttention (no sparsity).
+    Full attention policy that loads all available blocks.

-    This is the default behavior with standard causal attention.
-    All tokens attend to all previous tokens.
+    This is the default behavior with no sparsity - all previous
+    KV cache blocks are loaded for each query chunk.

     Use this as:
     - A baseline for comparing sparse policies
@@ -26,55 +25,15 @@ class FullAttentionPolicy(AttentionPolicy):
|
|||||||
# Full attention supports both prefill and decode
|
# Full attention supports both prefill and decode
|
||||||
supports_prefill = True
|
supports_prefill = True
|
||||||
supports_decode = True
|
supports_decode = True
|
||||||
|
requires_block_selection = False # Load all blocks, no selective loading
|
||||||
|
|
||||||
def estimate(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
available_blocks: List[int],
|
||||||
k: torch.Tensor,
|
ctx: PolicyContext,
|
||||||
layer_id: int,
|
) -> List[int]:
|
||||||
) -> Optional[torch.Tensor]:
|
"""Return all blocks - no sparsity."""
|
||||||
"""
|
return available_blocks
|
||||||
Full attention - no sparse mask needed.
|
|
||||||
|
|
||||||
Returns None to indicate full attention should be used.
|
|
||||||
"""
|
|
||||||
return None
|
|
||||||
|
|
||||||
def compute_prefill(
|
|
||||||
self,
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
layer_id: int,
|
|
||||||
softmax_scale: float,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute full causal attention using FlashAttention.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
q: Query tensor [seq_len, num_heads, head_dim]
|
|
||||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
layer_id: Transformer layer index
|
|
||||||
softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Attention output [seq_len, num_heads, head_dim]
|
|
||||||
"""
|
|
||||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
|
||||||
|
|
||||||
seq_len = q.shape[0]
|
|
||||||
cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
|
|
||||||
|
|
||||||
return flash_attn_varlen_func(
|
|
||||||
q, k, v,
|
|
||||||
cu_seqlens_q=cu_seqlens,
|
|
||||||
cu_seqlens_k=cu_seqlens,
|
|
||||||
max_seqlen_q=seq_len,
|
|
||||||
max_seqlen_k=seq_len,
|
|
||||||
softmax_scale=softmax_scale,
|
|
||||||
causal=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return "FullAttentionPolicy()"
|
return "FullAttentionPolicy()"
|
||||||
|
|||||||
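Under the new interface above, full attention degenerates to returning every block unchanged. A runnable sketch of that contract with stand-in classes (pure Python, no torch; the `PolicyContext` fields beyond `query_chunk_idx` are assumptions for illustration, not the real dataclass):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class PolicyContext:
    """Stand-in for the diff's PolicyContext (only a few assumed fields)."""
    query_chunk_idx: int
    layer_id: int
    is_prefill: bool


class FullAttentionPolicy:
    """Sketch of the new interface: full attention keeps every block."""
    supports_prefill = True
    supports_decode = True
    requires_block_selection = False  # engine may load all blocks regardless

    def select_blocks(self, available_blocks: List[int], ctx: PolicyContext) -> List[int]:
        # No sparsity: return every available block unchanged.
        return available_blocks


policy = FullAttentionPolicy()
ctx = PolicyContext(query_chunk_idx=0, layer_id=0, is_prefill=True)
print(policy.select_blocks([0, 1, 2, 3], ctx))  # [0, 1, 2, 3]
```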
@@ -10,10 +10,10 @@
 import torch
 import torch.nn.functional as F
 
-from nanovllm.kvcache.sparse.policy import AttentionPolicy, PolicyContext
+from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
 
 
-class MInferencePolicy(AttentionPolicy):
+class MInferencePolicy(SparsePolicy):
     """
     MInference sparse prefill policy using vertical + slash pattern.
 
@@ -347,33 +347,6 @@ class MInferencePolicy(AttentionPolicy):
 
         return o
 
-    def compute_prefill(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer_id: int,
-        softmax_scale: float,
-    ) -> torch.Tensor:
-        """
-        Compute MInference sparse prefill attention.
-
-        This is the new unified interface for attention policies.
-        Delegates to sparse_prefill_attention (ignores softmax_scale as MInference
-        computes it internally from head_dim).
-
-        Args:
-            q: Query tensor [seq_len, num_heads, head_dim]
-            k: Key tensor [seq_len, num_kv_heads, head_dim]
-            v: Value tensor [seq_len, num_kv_heads, head_dim]
-            layer_id: Transformer layer index
-            softmax_scale: Softmax scaling factor (unused, computed internally)
-
-        Returns:
-            Attention output [seq_len, num_heads, head_dim]
-        """
-        return self.sparse_prefill_attention(q, k, v, layer_id)
-
     def __repr__(self) -> str:
         return (f"MInferencePolicy("
                 f"adaptive_budget={self.adaptive_budget}, "
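The "vertical + slash" pattern named in the MInference docstring can be illustrated without the Triton kernels. A hedged, pure-Python sketch over block indices only (the real policy selects vertical columns and slash diagonals adaptively from attention scores; the fixed sets here are illustrative assumptions):

```python
from typing import List, Set


def vertical_slash_mask(num_q_blocks: int, num_k_blocks: int,
                        vertical_cols: Set[int],
                        slash_offsets: Set[int]) -> List[List[bool]]:
    """Build a boolean block mask: chosen vertical columns (e.g. sink
    blocks every query attends to) plus 'slash' diagonals at fixed
    query-key offsets, restricted to the causal lower triangle."""
    mask = [[False] * num_k_blocks for _ in range(num_q_blocks)]
    for qi in range(num_q_blocks):
        for ki in range(num_k_blocks):
            if ki > qi:                      # causal: no future key blocks
                continue
            if ki in vertical_cols:          # vertical stripe
                mask[qi][ki] = True
            if (qi - ki) in slash_offsets:   # slash: fixed-distance diagonal
                mask[qi][ki] = True
    return mask


for row in vertical_slash_mask(4, 4, vertical_cols={0}, slash_offsets={0, 1}):
    print("".join("X" if v else "." for v in row))
```

With column 0 vertical and offsets {0, 1}, each query block keeps the sink block, its own diagonal block, and the block just before it.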
@@ -1,18 +1,13 @@
 """
-Base class for attention policies in layerwise offload mode.
+Base class for sparse attention policies.
 
-AttentionPolicy defines the interface for all attention computation,
-including full attention and sparse attention methods like XAttention.
-
-Key methods:
-- estimate(): Compute sparse attention mask (optional, returns None for full attention)
-- compute_prefill(): Compute prefill attention
-- compute_decode(): Compute decode attention (default implementation provided)
+Sparse attention policies determine which KV cache blocks to load
+from CPU for each query chunk during chunked attention computation.
 """
 
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import List, Optional, Any
 import torch
 
 # Import SparsePolicyType from config to avoid circular imports
@@ -22,10 +17,10 @@ from nanovllm.config import SparsePolicyType
 @dataclass
 class PolicyContext:
     """
-    Context passed to attention policy for block selection.
+    Context passed to sparse policy for block selection.
 
-    This dataclass contains all information needed by an attention policy
-    for sparse estimation and attention computation.
+    This dataclass contains all information needed by a sparse policy
+    to decide which blocks to load for the current query chunk.
     """
 
     query_chunk_idx: int
@@ -54,41 +49,40 @@ class PolicyContext:
     """Total KV sequence length so far (for reference)."""
 
 
-class AttentionPolicy(ABC):
+class SparsePolicy(ABC):
     """
-    Base class for attention policies in layerwise offload mode.
+    Abstract base class for sparse attention policies.
 
-    All attention computation goes through a policy, including both
-    full attention and sparse attention methods.
-
-    The policy interface is designed for layerwise offload where:
-    - The entire KV cache for a layer is on GPU during computation
-    - No need for block loading from CPU during attention
-    - estimate() returns a sparse mask (or None for full attention)
-    - compute_prefill()/compute_decode() perform the actual attention
+    Subclass this and implement select_blocks() to create custom
+    sparse attention patterns. The policy receives context about
+    the current query chunk and returns which KV blocks to load.
 
     Attributes:
         supports_prefill: Whether this policy can be used for prefill phase.
        supports_decode: Whether this policy can be used for decode phase.
 
     Example:
-        class MyPolicy(AttentionPolicy):
-            supports_prefill = True
+        class MySparsePolicy(SparsePolicy):
+            supports_prefill = False  # decode-only policy
             supports_decode = True
 
-            def estimate(self, q, k, layer_id):
-                # Return sparse mask or None
-                return None
-
-            def compute_prefill(self, q, k, v, layer_id, softmax_scale):
-                # Compute attention
-                return flash_attn_varlen_func(q, k, v, ...)
+            def select_blocks(self, available_blocks, ctx):
+                # Load first block and last 2 blocks
+                if len(available_blocks) <= 3:
+                    return available_blocks
+                return [available_blocks[0]] + available_blocks[-2:]
     """
 
     # Compatibility flags - override in subclasses
     supports_prefill: bool = True
     supports_decode: bool = True
 
+    # Whether this policy requires selective block loading during decode
+    # If True: OffloadEngine will call select_blocks() before loading KV from CPU
+    # If False: OffloadEngine will load all blocks (select_blocks ignored for load)
+    # Example: MInference=False (only affects attention), Quest=True (affects load)
+    requires_block_selection: bool = False
+
     def initialize(
         self,
         num_layers: int,
@@ -102,7 +96,7 @@ class AttentionPolicy(ABC):
         Initialize policy resources.
 
         Called by the framework after KV cache is allocated. Override this
-        to create metadata structures or pre-allocate buffers.
+        to create metadata structures (e.g., BlockMetadataManager for Quest).
         Default implementation does nothing.
 
         Args:
@@ -115,98 +109,76 @@ class AttentionPolicy(ABC):
         """
         pass
 
-    def estimate(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        layer_id: int,
-    ) -> Optional[torch.Tensor]:
-        """
-        Estimate sparse attention mask.
-
-        For sparse policies (e.g., XAttention), computes block-level importance
-        and returns a boolean mask indicating which blocks to attend.
-        For full attention policy, returns None.
-
-        This corresponds to xattn_estimate() in COMPASS.
-
-        Args:
-            q: Query tensor [seq_len, num_heads, head_dim]
-            k: Key tensor [seq_len, num_kv_heads, head_dim]
-            layer_id: Transformer layer index
-
-        Returns:
-            sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
-                or None for full attention
-        """
-        return None
-
     @abstractmethod
-    def compute_prefill(
+    def select_blocks(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer_id: int,
-        softmax_scale: float,
-    ) -> torch.Tensor:
+        available_blocks: List[int],
+        ctx: PolicyContext,
+    ) -> List[int]:
         """
-        Compute prefill attention.
+        Select which KV blocks to load for the current query chunk.
 
-        The entire KV cache for this layer is on GPU. Compute attention
-        between Q and K/V, optionally using sparse mask from estimate().
+        This is the core method that defines the sparse attention pattern.
+        The returned blocks will be loaded from CPU to GPU for attention
+        computation against the current query chunk.
 
         Args:
-            q: Query tensor [seq_len, num_heads, head_dim]
-            k: Key tensor [seq_len, num_kv_heads, head_dim]
-            v: Value tensor [seq_len, num_kv_heads, head_dim]
-            layer_id: Transformer layer index
-            softmax_scale: Softmax scaling factor (1/sqrt(head_dim))
+            available_blocks: List of CPU block IDs that contain KV cache
+                              from previous chunks. These are ordered by
+                              their position in the sequence.
+            ctx: PolicyContext with information about the current query
+                 chunk, layer, phase (prefill/decode), etc.
 
         Returns:
-            Attention output [seq_len, num_heads, head_dim]
+            List of block IDs to load (must be a subset of available_blocks).
+            The order may affect performance (sequential access is faster).
+            Returning [] means no previous blocks will be loaded.
         """
         pass
 
-    def compute_decode(
+    def on_prefill_offload(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
+        cpu_block_id: int,
         layer_id: int,
-        softmax_scale: float,
-    ) -> torch.Tensor:
+        k_cache: torch.Tensor,
+        num_valid_tokens: int,
+    ) -> None:
         """
-        Compute decode attention.
+        Hook called when a block is offloaded during prefill phase.
 
-        KV is provided from ring buffer, containing prefill tokens + decoded tokens.
-        Default implementation uses FlashAttention.
+        Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
+        Override this to collect metadata about blocks (e.g., min/max keys
+        for Quest-style selection). Default implementation does nothing.
 
         Args:
-            q: Query tensor [1, num_heads, head_dim]
-            k: Key tensor [context_len+1, num_kv_heads, head_dim]
-            v: Value tensor [context_len+1, num_kv_heads, head_dim]
+            cpu_block_id: The CPU block ID that will be written
             layer_id: Transformer layer index
-            softmax_scale: Softmax scaling factor
-
-        Returns:
-            Attention output [1, num_heads, head_dim]
+            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
+            num_valid_tokens: Number of valid tokens in this block
         """
-        from flash_attn.flash_attn_interface import flash_attn_varlen_func
-
-        context_len = k.shape[0]
-        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
-        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)
-
-        return flash_attn_varlen_func(
-            q, k, v,
-            cu_seqlens_q=cu_seqlens_q,
-            cu_seqlens_k=cu_seqlens_k,
-            max_seqlen_q=1,
-            max_seqlen_k=context_len,
-            softmax_scale=softmax_scale,
-            causal=False,
-        )
+        pass
+
+    def on_decode_offload(
+        self,
+        cpu_block_id: int,
+        layer_id: int,
+        k_cache: torch.Tensor,
+        num_valid_tokens: int,
+    ) -> None:
+        """
+        Hook called when a block is offloaded during decode phase.
+
+        Called BEFORE GPU→CPU copy, while k_cache is still on GPU.
+        Override this to update metadata about blocks. Default implementation
+        does nothing.
+
+        Args:
+            cpu_block_id: The CPU block ID that will be written
+            layer_id: Transformer layer index
+            k_cache: Key cache tensor [block_size, num_kv_heads, head_dim] (on GPU)
+            num_valid_tokens: Number of valid tokens in this block
+        """
+        pass
 
     def reset(self) -> None:
         """
@@ -217,9 +189,32 @@ class AttentionPolicy(ABC):
         """
         pass
 
+    def sparse_prefill_attention(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer_id: int,
+    ) -> torch.Tensor:
+        """
+        Compute sparse attention for prefill phase.
+
+        This method is called when supports_prefill=True and the policy
+        is used for GPU-only sparse prefill (no CPU offload).
+
+        Args:
+            q: Query tensor [seq_len, num_heads, head_dim]
+            k: Key tensor [seq_len, num_kv_heads, head_dim]
+            v: Value tensor [seq_len, num_kv_heads, head_dim]
+            layer_id: Current transformer layer index
+
+        Returns:
+            Attention output [seq_len, num_heads, head_dim]
+        """
+        raise NotImplementedError(
+            f"{self.__class__.__name__} does not implement sparse_prefill_attention. "
+            "Set supports_prefill=False or implement this method."
+        )
+
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
-
-
-# Backward compatibility alias
-SparsePolicy = AttentionPolicy
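The example policy from the base-class docstring above runs as-is against a minimal stand-in. A sketch (pure Python; the real `SparsePolicy` lives in `nanovllm.kvcache.sparse.policy` and also carries the offload hooks, which are omitted here):

```python
from abc import ABC, abstractmethod
from typing import List


class SparsePolicy(ABC):
    """Minimal stand-in for the diff's base class (offload hooks omitted)."""
    supports_prefill: bool = True
    supports_decode: bool = True
    requires_block_selection: bool = False

    @abstractmethod
    def select_blocks(self, available_blocks: List[int], ctx) -> List[int]:
        ...


class MySparsePolicy(SparsePolicy):
    supports_prefill = False  # decode-only policy
    supports_decode = True
    requires_block_selection = True  # selection must affect what is loaded

    def select_blocks(self, available_blocks: List[int], ctx=None) -> List[int]:
        # Load first block and last 2 blocks
        if len(available_blocks) <= 3:
            return available_blocks
        return [available_blocks[0]] + available_blocks[-2:]


print(MySparsePolicy().select_blocks(list(range(8))))  # [0, 6, 7]
```

Note the return value is a subset of `available_blocks`, as the `select_blocks()` contract requires.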
@@ -11,7 +11,7 @@
 import torch
 from dataclasses import dataclass
 from typing import List, Tuple, Optional
-from .policy import AttentionPolicy, PolicyContext
+from .policy import SparsePolicy, PolicyContext
 
 logger = logging.getLogger(__name__)
 
@@ -137,7 +137,7 @@ class QuestConfig:
     """Always include this many recent blocks (last N blocks), in addition to Top-K."""
 
 
-class QuestPolicy(AttentionPolicy):
+class QuestPolicy(SparsePolicy):
     """
     Quest-style Top-K block selection using min/max key bounds.
 
@@ -317,25 +317,6 @@ class QuestPolicy(AttentionPolicy):
         if self.metadata is not None:
             self.metadata.reset()
 
-    def compute_prefill(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        layer_id: int,
-        softmax_scale: float,
-    ) -> torch.Tensor:
-        """
-        Quest does not support prefill - raises error.
-
-        Quest is a decode-only policy for selective block loading.
-        For prefill, use FullAttentionPolicy or XAttentionPolicy.
-        """
-        raise NotImplementedError(
-            "QuestPolicy does not support prefill. "
-            "Use FullAttentionPolicy or XAttentionPolicy for prefill."
-        )
-
     def __repr__(self) -> str:
         return (
             f"QuestPolicy(topk={self.config.topk_blocks}, "
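The "Top-K block selection using min/max key bounds" named in the QuestPolicy docstring can be sketched without tensors. A hedged, pure-Python illustration with a hypothetical `quest_select` helper (the real policy works per attention head on torch tensors collected by the offload hooks; scalar lists here are a simplification):

```python
from typing import List, Tuple


def quest_select(q: List[float],
                 block_bounds: List[Tuple[List[float], List[float]]],
                 topk: int,
                 recent: int = 1) -> List[int]:
    """Score each block by an upper bound on q . k derived from its
    element-wise (min_key, max_key) bounds, then keep the Top-K blocks
    plus the most recent ones."""
    scores = []
    for idx, (kmin, kmax) in enumerate(block_bounds):
        # Per dimension the dot-product contribution is at most
        # max(q_d * min_d, q_d * max_d), correct for either sign of q_d.
        ub = sum(max(qd * lo, qd * hi) for qd, lo, hi in zip(q, kmin, kmax))
        scores.append((ub, idx))
    n = len(block_bounds)
    recent_ids = set(range(max(0, n - recent), n))
    topk_ids = {idx for _, idx in sorted(scores, reverse=True)[:topk]}
    return sorted(topk_ids | recent_ids)


bounds = [
    ([0.9, -0.1], [1.0, 0.1]),     # block 0: keys aligned with the query
    ([-1.0, -1.0], [-0.5, -0.5]),  # block 1: keys opposing the query
    ([0.0, 0.0], [0.1, 0.1]),      # block 2: near-zero keys
    ([0.2, 0.0], [0.4, 0.2]),      # block 3: most recent block
]
print(quest_select([1.0, -1.0], bounds, topk=1, recent=1))  # [0, 3]
```

This mirrors the `QuestConfig` fields shown in the diff: `topk` plays the role of `topk_blocks`, and `recent` the "always include this many recent blocks" knob.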
@@ -4,56 +4,48 @@ XAttention sparse attention policy for nano-vllm.
|
|||||||
Implements the XAttention algorithm from COMPASS, using chunked estimation
|
Implements the XAttention algorithm from COMPASS, using chunked estimation
|
||||||
and block sparse attention for efficient long-context inference.
|
and block sparse attention for efficient long-context inference.
|
||||||
|
|
||||||
Architecture:
|
|
||||||
XAttention = Estimate (Triton) + Compute (BSA)
|
|
||||||
- Estimate: xattn_estimate() computes block-level importance scores
|
|
||||||
- Compute: block_sparse_attn_func() executes sparse attention
|
|
||||||
|
|
||||||
Reference: COMPASS/compass/src/Xattention.py
|
Reference: COMPASS/compass/src/Xattention.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import math
|
import math
|
||||||
from typing import Optional
|
from typing import List, Optional
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from nanovllm.kvcache.sparse.policy import AttentionPolicy
|
from nanovllm.kvcache.sparse.policy import SparsePolicy, PolicyContext
|
||||||
|
from nanovllm.kvcache.sparse.kernels import (
|
||||||
# BSA block size is fixed at 128 (hardcoded in block_sparse_attn)
|
flat_group_gemm_fuse_reshape,
|
||||||
BSA_BLOCK_SIZE = 128
|
softmax_fuse_block_sum,
|
||||||
|
)
|
||||||
|
from nanovllm.kvcache.sparse.utils import find_blocks_chunked
|
||||||
|
|
||||||
|
|
||||||
class XAttentionPolicy(AttentionPolicy):
|
class XAttentionPolicy(SparsePolicy):
|
||||||
"""
|
"""
|
||||||
XAttention sparse prefill policy using chunked estimation + block sparse attention.
|
XAttention sparse prefill policy using chunked estimation + block sparse attention.
|
||||||
|
|
||||||
This policy estimates sparse attention patterns by:
|
This policy estimates sparse attention patterns by:
|
||||||
1. Chunked QK computation using Triton kernels (via nanovllm.ops.xattn)
|
1. Chunked QK computation using Triton kernels
|
||||||
2. Block-wise softmax with importance scores
|
2. Block-wise softmax with importance scores
|
||||||
3. Block selection based on threshold
|
3. Block selection based on threshold
|
||||||
4. Block sparse attention computation using MIT-HAN-LAB BSA library
|
4. Block sparse attention computation
|
||||||
|
|
||||||
The key method is estimate() which calls xattn_estimate() from nanovllm.ops
|
|
||||||
to compute the sparse attention mask.
|
|
||||||
|
|
||||||
Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
|
Note: Requires Triton >= 2.1.0 and CUDA SM 80+ (RTX 3090, A100, H100, etc.)
|
||||||
BSA library: https://github.com/mit-han-lab/Block-Sparse-Attention
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
supports_prefill = True
|
supports_prefill = True
|
||||||
supports_decode = True # Uses default FlashAttention for decode
|
supports_decode = False # XAttention is prefill-only
|
||||||
|
requires_block_selection = False # Only affects attention computation
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
stride: int = 8,
|
stride: int = 8,
|
||||||
threshold: float = 0.9,
|
threshold: float = 0.9,
|
||||||
block_size: int = 128,
|
chunk_size: Optional[int] = None,
|
||||||
chunk_size: int = 16384,
|
|
||||||
use_triton: bool = True,
|
use_triton: bool = True,
|
||||||
keep_sink: bool = False,
|
keep_sink: bool = False,
|
||||||
keep_recent: bool = False,
|
keep_recent: bool = False,
|
||||||
norm: float = 1.0,
|
norm: float = 1.0,
|
||||||
use_bsa: bool = True,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize XAttention policy.
|
Initialize XAttention policy.
|
||||||
@@ -61,28 +53,19 @@ class XAttentionPolicy(AttentionPolicy):
|
|||||||
Args:
|
Args:
|
||||||
stride: Stride for reorganizing Q/K (default: 8)
|
stride: Stride for reorganizing Q/K (default: 8)
|
||||||
threshold: Block selection threshold, 0-1 (default: 0.9)
|
threshold: Block selection threshold, 0-1 (default: 0.9)
|
||||||
block_size: Block size for sparse attention (default: 128, must match BSA)
|
chunk_size: Chunk size for estimation (auto if None)
|
||||||
chunk_size: Chunk size for estimation (default: 16384)
|
|
||||||
use_triton: Use Triton kernels (requires SM 80+)
|
use_triton: Use Triton kernels (requires SM 80+)
|
||||||
keep_sink: Always keep first block (sink tokens)
|
keep_sink: Always keep first block (sink tokens)
|
||||||
keep_recent: Always keep recent diagonal blocks
|
keep_recent: Always keep recent diagonal blocks
|
||||||
norm: Normalization factor for attention scores
|
norm: Normalization factor for attention scores
|
||||||
use_bsa: Use Block Sparse Attention library (default: True)
|
|
||||||
"""
|
"""
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
self.block_size = block_size
|
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
self.use_triton = use_triton
|
self.use_triton = use_triton
|
||||||
self.keep_sink = keep_sink
|
self.keep_sink = keep_sink
|
||||||
self.keep_recent = keep_recent
|
self.keep_recent = keep_recent
|
||||||
self.norm = norm
|
self.norm = norm
|
||||||
self.use_bsa = use_bsa
|
|
||||||
|
|
||||||
# BSA requires block_size = 128
|
|
||||||
if self.use_bsa and self.block_size != BSA_BLOCK_SIZE:
|
|
||||||
print(f"XAttention: BSA requires block_size=128, adjusting from {self.block_size}")
|
|
||||||
self.block_size = BSA_BLOCK_SIZE
|
|
||||||
|
|
||||||
# Check Triton availability
|
# Check Triton availability
|
||||||
if self.use_triton:
|
if self.use_triton:
|
||||||
@@ -96,206 +79,379 @@ class XAttentionPolicy(AttentionPolicy):
|
|||||||
self.use_triton = False
|
self.use_triton = False
|
||||||
print("XAttention: Triton not available. Falling back to PyTorch.")
|
print("XAttention: Triton not available. Falling back to PyTorch.")
|
||||||
|
|
||||||
# Check BSA availability
|
def select_blocks(
|
||||||
if self.use_bsa:
|
|
||||||
try:
|
|
||||||
from block_sparse_attn import block_sparse_attn_func
|
|
||||||
except ImportError:
|
|
||||||
self.use_bsa = False
|
|
||||||
print("XAttention: block_sparse_attn not available. Falling back to FlashAttention.")
|
|
||||||
|
|
||||||
def estimate(
|
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
available_blocks: List[int],
|
||||||
k: torch.Tensor,
|
ctx: PolicyContext,
|
||||||
layer_id: int,
|
) -> List[int]:
|
||||||
) -> Optional[torch.Tensor]:
|
|
||||||
"""
|
"""
|
||||||
Estimate sparse attention mask using XAttention algorithm.
|
Select blocks for decode phase.
|
||||||
|
|
||||||
Calls xattn_estimate() from nanovllm.ops.xattn to compute block-level
|
XAttention is prefill-only, so this method is only used as a fallback.
|
||||||
importance scores and generate a sparse boolean mask.
|
Returns all available blocks by default.
|
||||||
|
|
||||||
Args:
|
|
||||||
q: Query tensor [seq_len, num_heads, head_dim]
|
|
||||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
layer_id: Transformer layer index
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
sparse_mask: [batch, num_heads, q_blocks, k_blocks] boolean mask,
|
|
||||||
or None if estimation fails (fallback to full attention)
|
|
||||||
"""
|
"""
|
||||||
try:
|
# XAttention is prefill-only, but we need to implement this abstract method
|
||||||
from nanovllm.ops.xattn import xattn_estimate
|
# Since requires_block_selection=False, this won't be called for loading
|
||||||
|
return available_blocks
|
||||||
|
|
||||||
seq_len, num_heads, head_dim = q.shape
|
def sparse_prefill_attention(
|
||||||
num_kv_heads = k.shape[1]
|
|
||||||
|
|
||||||
# Convert to [batch, heads, seq, dim] format expected by xattn_estimate
|
|
||||||
q_bhsd = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, seq_len, head_dim]
|
|
||||||
k_bhsd = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, seq_len, head_dim]
|
|
||||||
|
|
||||||
# Handle GQA: expand k to match q heads for estimation
|
|
||||||
if num_kv_heads != num_heads:
|
|
||||||
# GQA: expand k by repeating
|
|
||||||
repeat_factor = num_heads // num_kv_heads
|
|
||||||
k_bhsd = k_bhsd.repeat(1, repeat_factor, 1, 1)
|
|
||||||
|
|
||||||
# Call xattn_estimate
|
|
||||||
attn_sums, sparse_mask = xattn_estimate(
|
|
||||||
q_bhsd, k_bhsd,
|
|
||||||
block_size=self.block_size,
|
|
||||||
stride=self.stride,
|
|
||||||
norm=self.norm,
|
|
||||||
threshold=self.threshold,
|
|
||||||
chunk_size=self.chunk_size,
|
|
||||||
use_triton=self.use_triton,
|
|
||||||
causal=True,
|
|
||||||
keep_sink=self.keep_sink,
|
|
||||||
keep_recent=self.keep_recent,
|
|
||||||
)
|
|
||||||
|
|
||||||
return sparse_mask
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# If estimation fails, return None to use full attention
|
|
||||||
print(f"XAttention estimate failed: {e}, falling back to full attention")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def compute_prefill(
|
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
k: torch.Tensor,
|
k: torch.Tensor,
|
||||||
v: torch.Tensor,
|
v: torch.Tensor,
|
||||||
layer_id: int,
|
layer_id: int,
|
||||||
softmax_scale: float,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute XAttention sparse prefill attention.
|
Compute XAttention sparse attention for prefill.
|
||||||
|
|
||||||
Flow:
|
|
||||||
1. Call estimate() to get sparse mask
|
|
||||||
2. If mask is None or BSA unavailable, use full FlashAttention
|
|
||||||
3. Otherwise, use block_sparse_attn_func with mask
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
q: Query tensor [seq_len, num_heads, head_dim]
|
q: Query tensor [seq_len, num_heads, head_dim]
|
||||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
||||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
||||||
layer_id: Transformer layer index
|
layer_id: Current transformer layer index
|
||||||
softmax_scale: Softmax scaling factor
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Attention output [seq_len, num_heads, head_dim]
|
Attention output [seq_len, num_heads, head_dim]
|
||||||
"""
|
"""
|
||||||
# If BSA is disabled, use full attention directly (skip estimation)
|
seq_len = q.shape[0]
|
||||||
if not self.use_bsa:
|
num_heads = q.shape[1]
|
||||||
return self._full_attention(q, k, v, softmax_scale)
|
head_dim = q.shape[2]
|
||||||
|
|
||||||
# Step 1: Estimate sparse mask
|
|
||||||
sparse_mask = self.estimate(q, k, layer_id)
|
|
||||||
|
|
||||||
# Step 2: Compute attention
|
|
||||||
if sparse_mask is None:
|
|
||||||
# Estimation failed, fallback to full FlashAttention
|
|
||||||
return self._full_attention(q, k, v, softmax_scale)
|
|
||||||
|
|
||||||
# Use block sparse attention with mask
|
|
||||||
return self._block_sparse_attention(q, k, v, sparse_mask, softmax_scale)
|
|
||||||
|
|
||||||
def _block_sparse_attention(
|
|
||||||
self,
|
|
||||||
q: torch.Tensor,
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
sparse_mask: torch.Tensor,
|
|
||||||
softmax_scale: float,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute block sparse attention using MIT-HAN-LAB BSA library.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
q: Query tensor [seq_len, num_heads, head_dim]
|
|
||||||
k: Key tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
v: Value tensor [seq_len, num_kv_heads, head_dim]
|
|
||||||
sparse_mask: Block mask [batch, num_heads, q_blocks, k_blocks]
|
|
||||||
softmax_scale: Softmax scaling factor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Attention output [seq_len, num_heads, head_dim]
|
|
||||||
"""
|
|
||||||
from block_sparse_attn import block_sparse_attn_func
|
|
||||||
|
|
||||||
seq_len, num_heads, head_dim = q.shape
|
|
||||||
num_kv_heads = k.shape[1]
|
num_kv_heads = k.shape[1]
|
||||||
|
|
||||||
-        # Handle GQA: expand K/V to match Q heads
-        if num_kv_heads != num_heads:
-            repeat_factor = num_heads // num_kv_heads
-            k = k.repeat_interleave(repeat_factor, dim=1)
-            v = v.repeat_interleave(repeat_factor, dim=1)
-
-        # Cumulative sequence lengths (batch=1)
-        cu_seqlens_q = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
-        cu_seqlens_k = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
-
-        # Head mask type: 1 for all heads using block sparse
-        head_mask_type = torch.ones(num_heads, dtype=torch.int32, device=q.device)
-
-        # Trim sparse_mask to actual block counts
-        q_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
-        k_blocks = (seq_len + BSA_BLOCK_SIZE - 1) // BSA_BLOCK_SIZE
-        block_mask = sparse_mask[:, :, :q_blocks, :k_blocks].contiguous()
-
-        # Call BSA
-        attn_output = block_sparse_attn_func(
-            q, k, v,
-            cu_seqlens_q, cu_seqlens_k,
-            head_mask_type,
-            None,  # streaming_info (left_mask)
-            block_mask,
-            seq_len, seq_len,
-            p_dropout=0.0,
-            deterministic=True,
-            softmax_scale=softmax_scale,
-            is_causal=True,
-        )
        # Use FlashAttention directly for CPU offload mode
        # FlashAttention supports GQA natively
        try:
            from flash_attn.flash_attn_interface import flash_attn_varlen_func

            # Cumulative sequence lengths (batch=1)
            cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

            attn_output = flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=seq_len,
                max_seqlen_k=seq_len,
                softmax_scale=1.0 / math.sqrt(head_dim),
                causal=True,
            )
            return attn_output
        except Exception as e:
            # Fallback: PyTorch SDPA (supports GQA natively)
            print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                is_causal=True,
                scale=1.0 / math.sqrt(head_dim)
            )
            return attn_output
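As a toy sketch of what the removed `repeat_interleave` GQA expansion does to the head axis (plain Python lists standing in for tensors; the `kv0`/`kv1` head labels are made up for illustration):

```python
def repeat_interleave(heads, repeat_factor):
    """Expand each KV head repeat_factor times, mirroring
    torch.repeat_interleave(..., dim=1): query head i reads
    expanded slot i, i.e. original KV head i // repeat_factor."""
    return [h for h in heads for _ in range(repeat_factor)]

num_heads, num_kv_heads = 8, 2
repeat_factor = num_heads // num_kv_heads  # 4
expanded = repeat_interleave(["kv0", "kv1"], repeat_factor)
print(expanded)  # ['kv0', 'kv0', 'kv0', 'kv0', 'kv1', 'kv1', 'kv1', 'kv1']
```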
    def _xattn_offload_prefill(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        causal: bool = True,
    ) -> torch.Tensor:
        """
        Simplified XAttention prefill for CPU offload mode.

        Uses FlashAttention with the full context, since chunked estimation
        with full key_states requires special handling.
        """
        batch_size, num_heads, q_len, head_dim = query_states.shape
        _, _, k_len, _ = key_states.shape

        # Use FlashAttention with full context.
        # In offload mode, keys are already on CPU and loaded as needed.
        try:
            from flash_attn.flash_attn_interface import flash_attn_varlen_func

            # Convert to [seq, heads, dim] format
            q = query_states.squeeze(0).transpose(0, 1)  # [q_len, num_heads, head_dim]
            k = key_states.squeeze(0).transpose(0, 1)    # [k_len, num_heads, head_dim]
            v = value_states.squeeze(0).transpose(0, 1)  # [k_len, num_heads, head_dim]

            cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
            cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)

            attn_output = flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=k_len,
                softmax_scale=1.0 / math.sqrt(head_dim),
                causal=causal,
            )

            # Convert back to [batch, heads, seq, dim]
            attn_output = attn_output.unsqueeze(0).transpose(1, 2)  # [1, num_heads, q_len, head_dim]

            return attn_output

        except Exception as e:
            # Final fallback: PyTorch SDPA
            print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
                attn_output = F.scaled_dot_product_attention(
                    query_states, key_states, value_states,
                    attn_mask=None,
                    is_causal=causal,
                    scale=1.0 / math.sqrt(head_dim)
                )
            return attn_output
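The `cu_seqlens` tensors built above are cumulative prefix sums of sequence lengths, which is how `flash_attn_varlen_func` locates each sequence inside a packed batch. A minimal pure-Python sketch of that layout:

```python
def make_cu_seqlens(seq_lens):
    """Cumulative sequence-length prefix used by varlen attention:
    cu_seqlens[i] is the start offset of sequence i in the packed
    token dimension, cu_seqlens[-1] is the total token count."""
    cu = [0]
    for n in seq_lens:
        cu.append(cu[-1] + n)
    return cu

print(make_cu_seqlens([5]))        # [0, 5]  (the batch=1 case above)
print(make_cu_seqlens([3, 4, 2]))  # [0, 3, 7, 9]
```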
    def _xattn_prefill(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        stride: int,
        norm: float,
        threshold: float,
        block_size: int = 128,
        use_triton: bool = True,
        causal: bool = True,
        chunk_size: Optional[int] = None,
        keep_sink: bool = False,
        keep_recent: bool = False,
    ) -> torch.Tensor:
        """
        XAttention prefill implementation.

        Args:
            query_states: [batch, num_heads, q_len, head_dim]
            key_states: [batch, num_heads, k_len, head_dim]
            value_states: [batch, num_heads, k_len, head_dim]
            ... other params

        Returns:
            Attention output [batch, q_len, num_heads, head_dim]
        """
        batch_size, num_heads, k_len, head_dim = key_states.shape
        _, _, q_len, _ = query_states.shape

        # Auto-compute chunk_size if not specified
        if chunk_size is None:
            chunk_size = int(
                max(
                    min(
                        max(2048, 1 << (k_len - 1).bit_length()),
                        128 * 1024 * 2048 // (1 << (k_len - 1).bit_length()),
                    ),
                    2048,
                )
            )

        # Phase 1: Estimate sparse pattern
        attn_sums, approx_simple_mask = self._xattn_estimate(
            query_states,
            key_states,
            block_size=block_size,
            stride=stride,
            norm=norm,
            threshold=threshold,
            chunk_size=chunk_size,
            use_triton=use_triton,
            causal=causal,
            keep_sink=keep_sink,
            keep_recent=keep_recent,
        )

        # Phase 2: Block sparse attention
        # For now, use FlashAttention as fallback since block_sparse_attn_func may not be available
        attn_output = self._block_sparse_attention_fallback(
            query_states, key_states, value_states,
            approx_simple_mask, block_size, q_len, k_len
        )

        return attn_output
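The `chunk_size` heuristic above can be read as: round `k_len` up to a power of two, cap it by a fixed memory budget (`128 * 1024 * 2048` divided by that power of two), and floor the result at 2048. A plain-Python restatement of the same arithmetic for sanity-checking (no tensors, just the formula as written):

```python
def auto_chunk_size(k_len):
    """Mirror of the auto chunk_size computation in _xattn_prefill."""
    pow2 = 1 << (k_len - 1).bit_length()  # k_len rounded up to a power of two
    return int(max(min(max(2048, pow2), 128 * 1024 * 2048 // pow2), 2048))

print(auto_chunk_size(1000))   # 2048   (small contexts floor at 2048)
print(auto_chunk_size(5000))   # 8192   (pow2 = 8192, budget = 32768)
print(auto_chunk_size(10000))  # 16384  (pow2 = budget = 16384)
```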
-    def _full_attention(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        softmax_scale: float,
-    ) -> torch.Tensor:
-        """
-        Compute full causal attention using FlashAttention.
-
-        Args:
-            q: Query tensor [seq_len, num_heads, head_dim]
-            k: Key tensor [seq_len, num_kv_heads, head_dim]
-            v: Value tensor [seq_len, num_kv_heads, head_dim]
-            softmax_scale: Softmax scaling factor
-
-        Returns:
-            Attention output [seq_len, num_heads, head_dim]
-        """
-        from flash_attn.flash_attn_interface import flash_attn_varlen_func
-
-        seq_len = q.shape[0]
-        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)
-
-        return flash_attn_varlen_func(
-            q, k, v,
-            cu_seqlens_q=cu_seqlens,
-            cu_seqlens_k=cu_seqlens,
-            max_seqlen_q=seq_len,
-            max_seqlen_k=seq_len,
-            softmax_scale=softmax_scale,
-            causal=True,
-        )
    def _xattn_estimate(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        block_size: int,
        stride: int,
        norm: float = 1,
        softmax: bool = True,
        threshold: float = 0.9,
        chunk_size: int = 16384,
        use_triton: bool = True,
        causal: bool = True,
        keep_sink: bool = False,
        keep_recent: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Estimate sparse attention pattern using chunked computation.

        Returns:
            attn_sums: [batch, heads, q_blocks, k_blocks] - importance scores
            simple_masks: [batch, heads, q_blocks, k_blocks] - boolean masks
        """
        batch_size, num_kv_head, k_len, head_dim = key_states.shape
        batch_size, num_q_head, q_len, head_dim = query_states.shape

        k_num_to_pad = ((k_len + chunk_size - 1) // chunk_size) * chunk_size - k_len
        q_num_to_pad = ((q_len + chunk_size - 1) // chunk_size) * chunk_size - q_len
        k_chunk_num = (k_len + k_num_to_pad) // chunk_size
        k_block_num = (k_len + k_num_to_pad) // block_size
        q_chunk_num = (q_len + q_num_to_pad) // chunk_size
        q_block_num = (q_len + q_num_to_pad) // block_size

        # Pad inputs
        if k_num_to_pad > 0:
            pad_key_states = F.pad(key_states, (0, 0, 0, k_num_to_pad), value=0)
        else:
            pad_key_states = key_states
        if q_num_to_pad > 0:
            pad_query_states = F.pad(query_states, (0, 0, 0, q_num_to_pad), value=0)
        else:
            pad_query_states = query_states

        reshaped_chunk_size = chunk_size // stride
        reshaped_block_size = block_size // stride
        k_reshaped_seq_len = (k_len + k_num_to_pad) // stride

        attn_sum_list = []
        simple_mask_list = []

        for chunk_idx in range(q_chunk_num):
            if use_triton:
                # Triton GEMM + Softmax
                attn_weights_slice = flat_group_gemm_fuse_reshape(
                    pad_query_states[:, :, (chunk_idx * reshaped_chunk_size) * stride : (chunk_idx * reshaped_chunk_size + reshaped_chunk_size) * stride, :],
                    pad_key_states,
                    stride,
                    (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
                    (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
                    is_causal=causal,
                )

                attn_sum = softmax_fuse_block_sum(
                    attn_weights_slice,
                    reshaped_block_size,
                    min(4096, reshaped_block_size),
                    (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size,
                    (k_block_num - q_block_num) * reshaped_block_size + chunk_idx * reshaped_chunk_size + reshaped_chunk_size,
                    k_reshaped_seq_len - (k_num_to_pad // stride),
                    1.4426950408889634 / math.sqrt(head_dim) / stride / norm,
                    is_causal=causal,
                )
            else:
                # PyTorch fallback
                chunk_size_actual = reshaped_chunk_size
                chunk_start = chunk_idx * chunk_size_actual
                chunk_end = chunk_start + chunk_size_actual

                chunked_query = pad_query_states[:, :, chunk_start * stride:chunk_end * stride:stride, :]
                attn_weights_slice = torch.matmul(chunked_query, pad_key_states.transpose(2, 3))
                attn_weights_slice = attn_weights_slice / math.sqrt(head_dim) / stride / norm

                if causal:
                    causal_mask = torch.zeros((batch_size, num_q_head, chunk_size_actual, chunk_size_actual * k_chunk_num), device=key_states.device)
                    causal_mask[:, :, :, -(k_num_to_pad // stride):] = float("-inf")
                    # ... more causal mask logic ...
                    attn_weights_slice = attn_weights_slice + causal_mask

                attn_weights_slice = F.softmax(attn_weights_slice, dim=-1, dtype=torch.float32)
                attn_sum = attn_weights_slice.view(batch_size, num_q_head, chunk_size_actual // reshaped_block_size, reshaped_block_size, -1).sum(dim=-1).sum(dim=-2)

            # Find blocks based on threshold
            simple_mask = find_blocks_chunked(
                attn_sum,
                k_block_num - q_block_num + chunk_idx * (reshaped_chunk_size // reshaped_block_size),
                threshold,
                None,
                decoding=False,
                mode="prefill",
                causal=causal,
            )

            attn_sum_list.append(attn_sum)
            simple_mask_list.append(simple_mask)

        attn_sums = torch.cat(attn_sum_list, dim=-2)
        simple_masks = torch.cat(simple_mask_list, dim=-2)

        # Apply causal mask to block masks
        if causal:
            simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
                torch.tril(torch.ones(q_block_num, q_block_num, dtype=bool, device=key_states.device), diagonal=0),
                simple_masks[:, :, -q_block_num:, -q_block_num:],
                False,
            )

        if keep_sink:
            simple_masks[:, :, 0, :] = True

        if keep_recent:
            eye_matrix = torch.eye(q_block_num, device=simple_masks.device, dtype=bool)
            eye_matrix_expanded = eye_matrix.unsqueeze(0).unsqueeze(0).expand(1, num_q_head, q_block_num, q_block_num)
            simple_masks[:, :, -q_block_num:, -q_block_num:] = torch.where(
                eye_matrix_expanded, True, simple_masks[:, :, -q_block_num:, -q_block_num:]
            )

        return attn_sums, simple_masks
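The padding and block-count arithmetic in `_xattn_estimate` is plain ceil-division: pad each axis up to a multiple of `chunk_size`, then count chunks and blocks over the padded length. A small sketch with concrete numbers:

```python
def pad_and_blocks(seq_len, chunk_size, block_size):
    """Padding amount, chunk count, and block count as computed
    in _xattn_estimate for one sequence axis."""
    num_to_pad = ((seq_len + chunk_size - 1) // chunk_size) * chunk_size - seq_len
    padded = seq_len + num_to_pad
    return num_to_pad, padded // chunk_size, padded // block_size

print(pad_and_blocks(5000, 2048, 128))  # (1144, 3, 48): pad to 6144
print(pad_and_blocks(4096, 2048, 128))  # (0, 2, 32): already aligned
```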
    def _block_sparse_attention_fallback(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        mask: torch.Tensor,
        block_size: int,
        q_len: int,
        k_len: int,
    ) -> torch.Tensor:
        """
        Fallback implementation using FlashAttention.

        Since block_sparse_attn_func may not be available in all environments,
        this uses standard FlashAttention with full attention.
        """
        try:
            from flash_attn.flash_attn_interface import flash_attn_varlen_func

            batch_size, num_heads, _, head_dim = query_states.shape

            # Convert to [seq, heads, dim] format
            q = query_states.squeeze(0).transpose(0, 1)  # [q_len, num_heads, head_dim]
            k = key_states.squeeze(0).transpose(0, 1)    # [k_len, num_heads, head_dim]
            v = value_states.squeeze(0).transpose(0, 1)  # [k_len, num_heads, head_dim]

            cu_seqlens_q = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
            cu_seqlens_k = torch.tensor([0, k_len], dtype=torch.int32, device=q.device)

            attn_output = flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=q_len,
                max_seqlen_k=k_len,
                softmax_scale=1.0 / math.sqrt(head_dim),
                causal=True,
            )

            # Convert back to [batch, heads, seq, dim]
            attn_output = attn_output.unsqueeze(0).transpose(1, 2)

            return attn_output

        except Exception as e:
            # Final fallback: PyTorch SDPA
            print(f"XAttention: FlashAttention fallback failed ({e}), using PyTorch SDPA")
            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
                attn_output = F.scaled_dot_product_attention(
                    query_states, key_states, value_states,
                    attn_mask=None,
                    is_causal=True,
                    scale=1.0 / math.sqrt(query_states.shape[-1])
                )
            return attn_output
    def reset(self) -> None:
        """Reset policy state (no state to reset for XAttention)."""

@@ -305,6 +461,4 @@ class XAttentionPolicy(AttentionPolicy):
        return (f"XAttentionPolicy("
                f"stride={self.stride}, "
                f"threshold={self.threshold}, "
-               f"block_size={self.block_size}, "
-               f"use_triton={self.use_triton}, "
-               f"use_bsa={self.use_bsa})")
                f"use_triton={self.use_triton})")
@@ -98,10 +98,10 @@ class Attention(nn.Module):
                                     max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
                                     max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
                                     softmax_scale=self.scale, causal=True, block_table=context.block_tables)
-        elif context.attention_policy is not None:
-            # Attention via policy (GPU-only) - delegate to policy
-            o = context.attention_policy.compute_prefill(
-                q, k, v, self.layer_id, softmax_scale=self.scale
        elif context.sparse_prefill_policy is not None:
            # Sparse prefill (GPU-only) - delegate to policy
            o = context.sparse_prefill_policy.sparse_prefill_attention(
                q, k, v, self.layer_id
            )
        else:
            o = flash_attn_varlen_func(q, k, v,
@@ -1,38 +0,0 @@
"""
Operators module for nano-vLLM.

This module contains low-level attention operators and kernels.
"""

from nanovllm.ops.chunked_attention import (
    flash_attn_with_lse,
    merge_attention_outputs,
    chunked_attention_varlen,
    ChunkedPrefillState,
)

from nanovllm.ops.xattn import (
    xattn_estimate,
    xattn_estimate_chunked,
    flat_group_gemm_fuse_reshape,
    softmax_fuse_block_sum,
    find_blocks_chunked,
    create_causal_mask,
    compute_sparsity,
)

__all__ = [
    # chunked_attention
    "flash_attn_with_lse",
    "merge_attention_outputs",
    "chunked_attention_varlen",
    "ChunkedPrefillState",
    # xattn
    "xattn_estimate",
    "xattn_estimate_chunked",
    "flat_group_gemm_fuse_reshape",
    "softmax_fuse_block_sum",
    "find_blocks_chunked",
    "create_causal_mask",
    "compute_sparsity",
]
@@ -1,624 +0,0 @@
"""
Chunked attention implementation for CPU KV cache offloading.

This module implements flash attention with LSE (log-sum-exp) output,
enabling proper online softmax merging for chunked prefill.

Key functions:
- flash_attn_with_lse: Flash attention that returns output and LSE
- merge_attention_outputs: Merge outputs from multiple KV chunks
- chunked_attention_varlen: High-level interface for chunked attention
"""

import math
import torch
import triton
import triton.language as tl
from typing import Tuple, List, Optional
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel_with_lse(
    Q, K, V, Out, Lse,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_ob, stride_oh, stride_om,
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    IS_CAUSAL: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """
    Flash attention forward kernel with LSE output.

    Implements the standard Flash Attention online softmax algorithm:
    - m_i: running max of attention scores
    - l_i: running sum of exp(scores - m_i)
    - acc_o: running sum of softmax(scores) @ V (unnormalized)

    Final output: acc_o / l_i
    Final LSE: m_i + log(l_i)
    """
    start_m = tl.program_id(0)
    off_hb = tl.program_id(1)
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)

    # Pointers
    q_ptrs = Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :])
    k_ptrs = K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :])
    v_ptrs = V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :])

    # Initialize running statistics
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # running max
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)  # running sum of exp
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)  # running output (unnormalized)

    # Load Q (once per block)
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(
                q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0
            )

    # Loop over K, V blocks
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # Load K
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )

        # Compute QK^T * scale
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        qk *= softmax_scale

        # Apply masks
        if not EVEN_N:
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))

        # Online softmax: compute block max
        m_ij = tl.max(qk, 1)  # [BLOCK_M]

        # New running max
        m_new = tl.maximum(m_i, m_ij)  # [BLOCK_M]

        # Rescale factor for previous accumulator
        alpha = tl.exp(m_i - m_new)  # [BLOCK_M]

        # Compute P = exp(qk - m_new)
        p = tl.exp(qk - m_new[:, None])  # [BLOCK_M, BLOCK_N]

        # Sum of current block
        l_ij = tl.sum(p, 1)  # [BLOCK_M]

        # Update running sum: l_new = l_i * alpha + l_ij
        l_new = l_i * alpha + l_ij

        # Rescale previous output and add new contribution
        acc_o = acc_o * alpha[:, None]

        # Load V
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )

        # acc_o += P @ V
        p = p.to(v.dtype)
        acc_o += tl.dot(p, v)

        # Update running statistics
        m_i = m_new
        l_i = l_new

    # Final normalization: output = acc_o / l_i
    acc_o = acc_o / l_i[:, None]

    # Compute LSE = m_i + log(l_i)
    lse_i = m_i + tl.log(l_i)

    # Store LSE
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    if EVEN_M:
        tl.store(lse_ptrs, lse_i)
    else:
        tl.store(lse_ptrs, lse_i, mask=offs_m < seqlen_q)

    # Store output
    out_ptrs = Out + off_b * stride_ob + off_h * stride_oh + (offs_m[:, None] * stride_om + offs_d[None, :])
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )
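The kernel's `m_i` / `l_i` / `acc_o` recurrence is the standard online-softmax update. A scalar pure-Python version (one query row, 1-D values, illustrative only) that can be checked against a direct two-pass softmax:

```python
import math

def online_softmax_attention(scores, values):
    """One-pass online softmax over (score, value) pairs, mirroring the
    m_i / l_i / acc_o recurrence in _fwd_kernel_with_lse (scalar toy)."""
    m_i, l_i, acc_o = float("-inf"), 0.0, 0.0
    for s, v in zip(scores, values):
        m_new = max(m_i, s)             # new running max
        alpha = math.exp(m_i - m_new)   # rescale factor for old accumulator
        p = math.exp(s - m_new)
        l_i = l_i * alpha + p           # running sum of exp
        acc_o = acc_o * alpha + p * v   # running unnormalized output
        m_i = m_new
    return acc_o / l_i, m_i + math.log(l_i)  # (output, LSE)

scores = [0.5, 2.0, -1.0]
values = [1.0, 2.0, 3.0]
out, lse = online_softmax_attention(scores, values)

# Reference: direct (two-pass) softmax attention
w = [math.exp(s) for s in scores]
ref = sum(wi * vi for wi, vi in zip(w, values)) / sum(w)
print(abs(out - ref) < 1e-9, abs(lse - math.log(sum(w))) < 1e-9)  # True True
```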
def flash_attn_with_lse(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flash attention forward pass that returns both output and LSE.

    Uses the flash_attn library, which natively supports GQA without memory overhead.

    Args:
        q: Query tensor [batch, seqlen_q, nheads_q, headdim]
        k: Key tensor [batch, seqlen_k, nheads_kv, headdim]
        v: Value tensor [batch, seqlen_k, nheads_kv, headdim]
        softmax_scale: Scaling factor (default: 1/sqrt(headdim))
        causal: Whether to apply causal masking

    Returns:
        out: Output tensor [batch, seqlen_q, nheads_q, headdim]
        lse: Log-sum-exp tensor [batch, nheads_q, seqlen_q]
    """
    from flash_attn.flash_attn_interface import flash_attn_func

    batch, seqlen_q, nheads_q, headdim = q.shape
    _, seqlen_k, nheads_kv, _ = k.shape

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    # Use flash_attn_func, which natively supports GQA (no memory overhead).
    # With return_attn_probs=True it returns (out, softmax_lse, S_dmask),
    # which is how we obtain the LSE.
    out, lse, _ = flash_attn_func(
        q, k, v,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=True,  # makes it return (out, softmax_lse, S_dmask)
    )

    # lse shape from flash_attn: [batch, nheads_q, seqlen_q_rounded]
    # Trim to actual seqlen_q
    lse = lse[:, :, :seqlen_q]

    return out, lse
@triton.jit
def _merge_lse_kernel(
    lse1_ptr, lse2_ptr, lse_out_ptr,
    num_elements: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused kernel for merging LSE values.

    IMPORTANT: Uses fp32 for exp/log operations to avoid precision loss.
    bf16 has only 7 bits of mantissa, causing significant errors in exp/log.
    """
    # Each program handles BLOCK_SIZE elements
    pid = tl.program_id(0)
    block_start = pid * BLOCK_SIZE

    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements

    # Load LSE values and convert to fp32 for precision
    lse1 = tl.load(lse1_ptr + offsets, mask=mask).to(tl.float32)
    lse2 = tl.load(lse2_ptr + offsets, mask=mask).to(tl.float32)

    # Compute max for numerical stability (in fp32)
    max_lse = tl.maximum(lse1, lse2)

    # Compute exp(lse - max_lse) in fp32
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)

    # Compute merged LSE: max_lse + log(exp1 + exp2) in fp32
    lse_merged = max_lse + tl.log(exp1 + exp2)

    # Store result (converted back to the output dtype)
    tl.store(lse_out_ptr + offsets, lse_merged, mask=mask)
@triton.jit
def _merge_output_kernel(
    o1_ptr, o2_ptr, lse1_ptr, lse2_ptr, o_out_ptr,
    batch: tl.constexpr, seqlen_q: tl.constexpr, nheads: tl.constexpr, headdim: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused kernel for merging attention outputs.

    IMPORTANT: Uses fp32 for exp operations and the weighted sum to avoid
    precision loss. This is critical for numerical accuracy in chunked attention.
    """
    # Each program handles BLOCK_SIZE elements along headdim for one (batch, seqlen_q, nheads) position
    pid_batch = tl.program_id(0)
    pid_seq = tl.program_id(1)
    pid_head = tl.program_id(2)

    # Compute LSE index: [batch, nheads, seqlen_q]
    lse_idx = pid_batch * nheads * seqlen_q + pid_head * seqlen_q + pid_seq

    # Load LSE values and convert to fp32 for precision
    lse1 = tl.load(lse1_ptr + lse_idx).to(tl.float32)
    lse2 = tl.load(lse2_ptr + lse_idx).to(tl.float32)

    # Compute max and scaling factors in fp32
    max_lse = tl.maximum(lse1, lse2)
    exp1 = tl.exp(lse1 - max_lse)
    exp2 = tl.exp(lse2 - max_lse)
    sum_exp = exp1 + exp2

    # Process headdim in chunks
    for d_offset in range(0, headdim, BLOCK_SIZE):
        d_idx = d_offset + tl.arange(0, BLOCK_SIZE)
        mask = d_idx < headdim

        # Compute output index: [batch, seqlen_q, nheads, headdim]
        base_idx = (pid_batch * seqlen_q * nheads * headdim +
                    pid_seq * nheads * headdim +
                    pid_head * headdim)
        o_idx = base_idx + d_idx

        # Load o1, o2 and convert to fp32 for the weighted sum
        o1_val = tl.load(o1_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)
        o2_val = tl.load(o2_ptr + o_idx, mask=mask, other=0.0).to(tl.float32)

        # Compute merged output in fp32: (o1 * exp1 + o2 * exp2) / sum_exp
        o_merged = (o1_val * exp1 + o2_val * exp2) / sum_exp

        # Store result (Triton converts back to the original dtype)
        tl.store(o_out_ptr + o_idx, o_merged, mask=mask)
def merge_attention_outputs(
    o1: torch.Tensor,
    lse1: torch.Tensor,
    o2: torch.Tensor,
    lse2: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Merge two attention outputs using online softmax (Triton fused kernels).

    This implements the online softmax merging formula:
    - m_new = max(lse1, lse2)
    - o_new = (exp(lse1 - m_new) * o1 + exp(lse2 - m_new) * o2) / (exp(lse1 - m_new) + exp(lse2 - m_new))
    - lse_new = m_new + log(exp(lse1 - m_new) + exp(lse2 - m_new))

    Args:
        o1: First output [batch, seqlen_q, nheads, headdim]
        lse1: First LSE [batch, nheads, seqlen_q]
        o2: Second output [batch, seqlen_q, nheads, headdim]
        lse2: Second LSE [batch, nheads, seqlen_q]

    Returns:
        o_merged: Merged output [batch, seqlen_q, nheads, headdim]
        lse_merged: Merged LSE [batch, nheads, seqlen_q]
    """
    batch, seqlen_q, nheads, headdim = o1.shape

    # Allocate output tensors
    o_merged = torch.empty_like(o1)
    lse_merged = torch.empty_like(lse1)

    # Launch LSE merge kernel
    num_lse_elements = batch * nheads * seqlen_q
    BLOCK_SIZE_LSE = 256
    grid_lse = (triton.cdiv(num_lse_elements, BLOCK_SIZE_LSE),)
    _merge_lse_kernel[grid_lse](
        lse1, lse2, lse_merged,
        num_lse_elements,
        BLOCK_SIZE=BLOCK_SIZE_LSE,
    )

    # Launch output merge kernel
    BLOCK_SIZE = 128
    grid_output = (batch, seqlen_q, nheads)
    _merge_output_kernel[grid_output](
        o1, o2, lse1, lse2, o_merged,
        batch, seqlen_q, nheads, headdim,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return o_merged, lse_merged

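The three-line merging formula in the docstring above can be sanity-checked without a GPU. Below is a stdlib-only sketch (an assumption for brevity: a single attention row, with scalar values standing in for head-dim vectors) showing that merging two per-chunk softmax results reproduces the softmax over the full score row:

```python
import math

def softmax_with_lse(scores, values):
    """Return sum_i softmax(scores)_i * values_i and the log-sum-exp of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    out = sum(e * v for e, v in zip(exps, values)) / denom
    return out, m + math.log(denom)

def merge(o1, lse1, o2, lse2):
    """Online-softmax merge, mirroring merge_attention_outputs for one scalar row."""
    m_new = max(lse1, lse2)
    e1, e2 = math.exp(lse1 - m_new), math.exp(lse2 - m_new)
    o_new = (o1 * e1 + o2 * e2) / (e1 + e2)
    lse_new = m_new + math.log(e1 + e2)
    return o_new, lse_new

scores = [0.5, -1.2, 3.0, 0.1, 2.2, -0.7]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

# Reference: softmax over the full score row
ref_o, ref_lse = softmax_with_lse(scores, values)

# Chunked: split the row in two, compute per-chunk, then merge
o1, lse1 = softmax_with_lse(scores[:3], values[:3])
o2, lse2 = softmax_with_lse(scores[3:], values[3:])
mo, mlse = merge(o1, lse1, o2, lse2)

assert abs(mo - ref_o) < 1e-9
assert abs(mlse - ref_lse) < 1e-9
```

The key identity is that `exp(lse)` is exactly the unnormalized softmax denominator of a chunk, so rescaling each partial output by `exp(lse - m_new)` restores a common normalization.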
def chunked_attention_varlen(
    q: torch.Tensor,
    kv_chunks: List[Tuple[torch.Tensor, torch.Tensor]],
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k_list: List[torch.Tensor],
    max_seqlen_q: int,
    max_seqlen_k_list: List[int],
    softmax_scale: Optional[float] = None,
    causal_mask_per_chunk: Optional[List[bool]] = None,
) -> torch.Tensor:
    """
    Compute attention with KV split across multiple chunks.

    This is the core function for chunked prefill. It computes attention
    against each KV chunk and merges the results using online softmax.

    For causal attention with chunked KV:
    - Last chunk (current tokens): apply the causal mask
    - Previous chunks: no causal mask (all previous tokens are valid context)

    Args:
        q: Query tensor [total_q_tokens, nheads, headdim]
        kv_chunks: List of (K, V) tuples, each [batch, seqlen_k_i, nheads, headdim]
        cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
        cu_seqlens_k_list: List of cumulative sequence lengths for each KV chunk
        max_seqlen_q: Maximum query sequence length
        max_seqlen_k_list: List of maximum key sequence lengths for each chunk
        softmax_scale: Scaling factor
        causal_mask_per_chunk: Whether to apply a causal mask for each chunk

    Returns:
        out: Output tensor [total_q_tokens, nheads, headdim]
    """
    if len(kv_chunks) == 0:
        raise ValueError("Need at least one KV chunk")

    nheads = q.shape[1]
    headdim = q.shape[2]
    batch = cu_seqlens_q.shape[0] - 1

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(headdim)

    if causal_mask_per_chunk is None:
        # Default: causal for the last chunk only
        causal_mask_per_chunk = [False] * (len(kv_chunks) - 1) + [True]

    # Initialize accumulated output and LSE
    accumulated_o = None
    accumulated_lse = None

    for chunk_idx, (k_chunk, v_chunk) in enumerate(kv_chunks):
        is_causal = causal_mask_per_chunk[chunk_idx]

        # Reshape Q for batch processing.
        # For varlen, each sequence would need to be handled separately;
        # for simplicity, assume a single sequence (batch=1) for now.
        q_batched = q.unsqueeze(0)  # [1, total_q, nheads, headdim]

        # Compute attention for this chunk
        chunk_o, chunk_lse = flash_attn_with_lse(
            q_batched,
            k_chunk,
            v_chunk,
            softmax_scale=softmax_scale,
            causal=is_causal,
        )

        # Merge with the accumulated result
        if accumulated_o is None:
            accumulated_o = chunk_o
            accumulated_lse = chunk_lse
        else:
            accumulated_o, accumulated_lse = merge_attention_outputs(
                accumulated_o, accumulated_lse,
                chunk_o, chunk_lse,
            )

    # Remove batch dimension
    return accumulated_o.squeeze(0)

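The accumulation loop above folds chunks in pairwise. This is valid because the online-softmax merge is associative: any bracketing of chunks yields the same result as softmax over the concatenated scores. A stdlib-only sketch of that property (scalar values as a simplifying assumption):

```python
import math

def partial(scores, values):
    """Per-chunk partial result: (weighted output, log-sum-exp) for one row."""
    m = max(scores)
    d = sum(math.exp(s - m) for s in scores)
    o = sum(math.exp(s - m) * v for s, v in zip(scores, values)) / d
    return o, m + math.log(d)

def merge(a, b):
    """Online-softmax merge of two (output, lse) partials."""
    (o1, l1), (o2, l2) = a, b
    m = max(l1, l2)
    e1, e2 = math.exp(l1 - m), math.exp(l2 - m)
    return (o1 * e1 + o2 * e2) / (e1 + e2), m + math.log(e1 + e2)

chunks = [([0.1, 2.0], [1.0, 2.0]),
          ([-1.0, 0.5], [3.0, 4.0]),
          ([1.5, 0.2], [5.0, 6.0])]
parts = [partial(s, v) for s, v in chunks]

left = merge(merge(parts[0], parts[1]), parts[2])   # fold left, as in the loop
right = merge(parts[0], merge(parts[1], parts[2]))  # fold right

assert abs(left[0] - right[0]) < 1e-9
assert abs(left[1] - right[1]) < 1e-9
```

This is why the loop can keep a single running (output, LSE) pair instead of retaining every chunk's result.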
class ChunkedPrefillState:
    """
    State for tracking chunked prefill progress.

    This class maintains the accumulated attention output and LSE
    across multiple prefill chunks.
    """

    def __init__(self, num_layers: int, dtype: torch.dtype, device: torch.device):
        self.num_layers = num_layers
        self.dtype = dtype
        self.device = device

        # Per-layer accumulated outputs.
        # Each entry: (accumulated_output, accumulated_lse) or None
        self.layer_states: List[Optional[Tuple[torch.Tensor, torch.Tensor]]] = [
            None for _ in range(num_layers)
        ]

        # Track how many chunks have been processed
        self.processed_chunks: int = 0

    def update_layer(
        self,
        layer_id: int,
        chunk_output: torch.Tensor,
        chunk_lse: torch.Tensor,
    ):
        """Update accumulated state for a layer with a new chunk's output."""
        if self.layer_states[layer_id] is None:
            self.layer_states[layer_id] = (chunk_output, chunk_lse)
        else:
            acc_o, acc_lse = self.layer_states[layer_id]
            merged_o, merged_lse = merge_attention_outputs(
                acc_o, acc_lse,
                chunk_output, chunk_lse,
            )
            self.layer_states[layer_id] = (merged_o, merged_lse)

    def get_layer_output(self, layer_id: int) -> Optional[torch.Tensor]:
        """Get the final accumulated output for a layer."""
        if self.layer_states[layer_id] is None:
            return None
        return self.layer_states[layer_id][0]

    def clear(self):
        """Clear all accumulated state."""
        self.layer_states = [None for _ in range(self.num_layers)]
        self.processed_chunks = 0

# Test function
def _test_chunked_attention():
    """Test chunked attention using flash_attn_with_lse and merge_attention_outputs."""
    from flash_attn.flash_attn_interface import flash_attn_func

    torch.manual_seed(42)

    print("=" * 70)
    print("Test: Chunked attention vs flash_attn_func (non-causal)")
    print("=" * 70)
    print("Splitting K,V into chunks, computing attention per chunk, then merging")
    print()

    for dtype in [torch.float16, torch.bfloat16]:
        for num_chunks in [64, 128, 256]:
            for batch, seqlen, nheads, headdim in [
                (1, 1024, 32, 128),
                (1, 2048, 32, 128),
                (1, 4096, 32, 128),
                (1, 8192, 32, 128),
            ]:
                # Generate random Q, K, V
                q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
                k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)
                v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=dtype)

                # Reference: full attention (non-causal)
                out_ref = flash_attn_func(q, k, v, causal=False)

                # Chunked attention: split K, V into chunks
                chunk_size = seqlen // num_chunks
                accumulated_o = None
                accumulated_lse = None

                for i in range(num_chunks):
                    start = i * chunk_size
                    end = (i + 1) * chunk_size

                    k_chunk = k[:, start:end, :, :]
                    v_chunk = v[:, start:end, :, :]

                    # Q attends to this K,V chunk (non-causal)
                    chunk_o, chunk_lse = flash_attn_with_lse(
                        q, k_chunk, v_chunk, causal=False
                    )

                    if accumulated_o is None:
                        accumulated_o = chunk_o
                        accumulated_lse = chunk_lse
                    else:
                        # Merge with previous chunks
                        accumulated_o, accumulated_lse = merge_attention_outputs(
                            accumulated_o, accumulated_lse,
                            chunk_o, chunk_lse
                        )

                # Compare against the reference
                out_diff = (out_ref - accumulated_o).abs()
                out_max_diff = out_diff.max().item()
                out_mean_diff = out_diff.mean().item()

                status = "PASS" if out_max_diff < 1e-2 else "FAIL"
                print(
                    f"[{status}] dtype={str(dtype):14s} chunks={num_chunks} "
                    f"shape=({batch}, {seqlen:4d}, {nheads:2d}, {headdim:3d}) "
                    f"max_diff={out_max_diff:.6f} mean_diff={out_mean_diff:.6f}"
                )

    print()
    print("=" * 70)
    print("Test completed!")


if __name__ == "__main__":
    _test_chunked_attention()
File diff suppressed because it is too large
@@ -14,9 +14,9 @@ class Context:
     context_lens: torch.Tensor | None = None
     block_tables: torch.Tensor | None = None
 
-    # Attention policy support (GPU-only path)
-    # When set, uses policy.compute_prefill() instead of FlashAttention
-    attention_policy: Any = None  # AttentionPolicy instance
+    # Sparse prefill attention support (GPU-only path)
+    # When set, uses policy.sparse_prefill_attention() instead of FlashAttention
+    sparse_prefill_policy: Any = None  # SparsePolicy instance with supports_prefill=True
 
 
 _CONTEXT = Context()
@@ -35,7 +35,7 @@ def set_context(
     slot_mapping=None,
     context_lens=None,
     block_tables=None,
-    attention_policy=None,
+    sparse_prefill_policy=None,
 ):
     global _CONTEXT
     _CONTEXT = Context(
@@ -47,7 +47,7 @@ def set_context(
         slot_mapping=slot_mapping,
         context_lens=context_lens,
         block_tables=block_tables,
-        attention_policy=attention_policy,
+        sparse_prefill_policy=sparse_prefill_policy,
    )
130 notes.md
@@ -1,130 +0,0 @@
# Notes: SparsePolicy Refactoring Research

## Sources

### Source 1: tzj/minference branch - policy.py
- Path: `nanovllm/kvcache/sparse/policy.py`
- Key design points:
  - The `PolicyContext` dataclass holds query_chunk_idx, num_query_chunks, layer_id, query, is_prefill, etc.
  - `select_blocks()` requires an offload_engine parameter
  - `compute_chunked_prefill()` and `compute_chunked_decode()` are complete attention pipelines
  - The `on_prefill_offload()` / `on_decode_offload()` hooks collect metadata

### Source 2: tzj/minference branch - full_policy.py
- Path: `nanovllm/kvcache/sparse/full_policy.py`
- Key implementation points:
  - `compute_chunked_prefill()` internally uses a ring-buffer pipeline to load blocks
  - Uses `flash_attn_with_lse` and `merge_attention_outputs` to merge attention across multiple chunks
  - `compute_chunked_decode()` handles prefilled blocks + the decode buffer

### Source 3: tzj/layer-offload branch - model_runner.py
- Path: `nanovllm/engine/model_runner.py`
- Key design points:
  - `run_layerwise_offload_prefill()` processes layer by layer; each layer computes full attention
  - `sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)` is a simple interface
  - The FULL policy takes the else branch via `if sparse_prefill_policy is None`

### Source 4: tzj/layer-offload branch - xattn.py
- Path: `nanovllm/kvcache/sparse/xattn.py`
- Key implementation points:
  - `sparse_prefill_attention()` uses FlashAttention directly (due to the chunked-prefill architectural constraint)
  - Triton kernels are kept for a future GPU-only mode

## Synthesized Findings

### Architecture Differences

| Aspect | Chunked Offload | Layerwise Offload |
|------|-----------------|-------------------|
| **Prefill flow** | chunk by chunk, across layers | layer by layer, full sequence |
| **KV storage** | offloaded immediately per chunk | offloaded after each layer's compute |
| **Attention compute** | multiple passes + merge | one full pass |
| **Block loading** | history must be loaded from CPU | not needed, already on GPU |
| **Policy responsibility** | full attention pipeline | attention kernel selection only |

### Simplifications in Layerwise Offload

1. **No block selection needed**: the whole layer's KV is on GPU, nothing to select
2. **No offload_engine parameter needed**: the policy does not load KV
3. **No merge_attention_outputs needed**: attention is computed in one pass
4. **No offload hooks needed**: offload is handled centrally in model_runner

### Design Recommendations

1. **Keep the interface simple**: only `compute_prefill_attention()` and `compute_decode_attention()` are needed
2. **FULL also implements the methods**: stop branching on `is None`; call every policy uniformly
3. **Drop unnecessary parameters**: no offload_engine, kvcache_manager, seq, etc.
4. **Unify naming**: use `compute_*_attention` rather than `sparse_prefill_attention`

## Code Examples

### Current call site (model_runner.py:876-891)

```python
# Sparse or Full attention
if self.sparse_prefill_policy is not None:
    # MInference or other sparse prefill policy
    attn_output = self.sparse_prefill_policy.sparse_prefill_attention(
        q, k, v, layer_id
    )
else:
    # Full attention using FlashAttention
    attn_output = flash_attn_varlen_func(
        q, k, v, ...
    )
```

### Proposed new call site

```python
# All policies are called uniformly
attn_output = self.attention_policy.compute_prefill_attention(
    q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
)
```

## Questions Resolved

- Q: Is PolicyContext still needed?
  - A: It can be simplified, since layerwise mode does not need chunk information

- Q: How is the decode phase handled?
  - A: **Decode does not need a policy!** The current `run_layerwise_offload_decode()` uses the standard `layer(positions, hidden_states, residual)` call, going through the Attention.forward() path

- Q: Why does decode not need sparsity?
  - A: Decode processes only 1 token per step, so sparsification is pointless. KV is loaded from the ring buffer and used directly with flash_attn_with_kvcache

## Key Insight

**The policy design for Layerwise Offload should focus on prefill only**:

```
Prefill: needs a policy
- Attention is computed over the whole sequence at once
- Sparse attention methods can be used (e.g., MInference's vertical+slash pattern)
- The policy receives q, k, v, layer_id, softmax_scale

Decode: needs no policy
- Only 1 query token per step
- KV is loaded from the ring buffer
- Uses standard flash_attn_with_kvcache
```

## Interface Comparison Summary

| Aspect | tzj/minference | tzj/layer-offload (new design) |
|------|----------------|---------------------------|
| Class name | SparsePolicy | AttentionPolicy |
| Prefill method | compute_chunked_prefill() | compute_attention() |
| Decode method | compute_chunked_decode() | not needed (standard path) |
| Needs offload_engine | yes | no |
| Needs kvcache_manager | yes | no |
| Needs seq | yes | no |
| Supports FULL | yes | yes |

## Migration Path

1. Keep `SparsePolicy` as an alias for `AttentionPolicy`
2. Keep `PolicyContext` for future extension
3. Keep the `select_blocks()` method signature (even though it is unused)
4. Remove the `requires_block_selection` property (not needed)
549 task_plan.md
@@ -1,549 +0,0 @@
# Task Plan: Refactor SparsePolicy for Layerwise Offload

## Goal
Refactor the SparsePolicy interface, following the design pattern of the tzj/minference branch, so that all attention variants can be abstracted as policies written to a single uniform convention. Adapt it to the characteristics of the current layerwise offload architecture (each layer's full KV resides on GPU).

## Background

### Comparison of the Two Offload Architectures

| Feature | tzj/minference (Chunked Offload) | tzj/layer-offload (Layerwise Offload) |
|------|----------------------------------|---------------------------------------|
| Processing granularity | one chunk at a time (block_size tokens) | one full layer at a time (all tokens) |
| KV location | historical chunks on CPU, must be loaded | whole layer's KV on GPU |
| Policy entry points | `compute_chunked_prefill()/decode()` | `compute_prefill()/decode()` |
| Needs offload_engine | yes (to load blocks) | no (KV already on GPU) |
| Mask computation | `select_blocks()` returns block IDs | `estimate()` returns a sparse mask |

### Policy Interface in tzj/minference

```python
class SparsePolicy(ABC):
    supports_prefill: bool
    supports_decode: bool

    @abstractmethod
    def select_blocks(self, available_blocks, offload_engine, ctx) -> List[int]

    @abstractmethod
    def compute_chunked_prefill(self, q, k, v, layer_id, ..., offload_engine, ...) -> Tensor

    @abstractmethod
    def compute_chunked_decode(self, q, layer_id, ..., offload_engine, ...) -> Tensor
```

### Policy Interface in the Current Branch (before refactoring)

```python
class SparsePolicy(ABC):
    supports_prefill: bool
    supports_decode: bool

    @abstractmethod
    def select_blocks(self, available_blocks, ctx) -> List[int]

    def sparse_prefill_attention(self, q, k, v, layer_id) -> Tensor
```

## Phases

- [x] Phase 1: Analyze differences and design the new interface
- [x] **Phase 0: Create the nanovllm.ops module** ✅ tests pass
- [ ] Phase 2: Refactor the AttentionPolicy base class
- [ ] Phase 3: Refactor FullAttentionPolicy
- [ ] Phase 4: Refactor XAttentionPolicy (including the estimate method)
- [ ] Phase 5: Update the model_runner call sites
- [ ] Phase 6: Test and validate

---

## Phase 0: Create the nanovllm.ops Module

### Goal
Extract the ops module from the tzj/minference branch to provide the low-level operators for XAttention estimation.

### Steps

1. **Create the directory structure**
   ```
   nanovllm/ops/
   ├── __init__.py
   ├── xattn.py              # xattn_estimate, xattn_estimate_chunked, Triton kernels
   └── chunked_attention.py  # flash_attn_with_lse, merge_attention_outputs (backup)
   ```

2. **Extract the files from tzj/minference**
   ```bash
   git show tzj/minference:nanovllm/ops/__init__.py > nanovllm/ops/__init__.py
   git show tzj/minference:nanovllm/ops/xattn.py > nanovllm/ops/xattn.py
   git show tzj/minference:nanovllm/ops/chunked_attention.py > nanovllm/ops/chunked_attention.py
   ```

3. **Cherry-pick the test file**
   ```bash
   git show tzj/minference:tests/test_xattn_estimate_chunked.py > tests/test_xattn_estimate_chunked.py
   ```

4. **Run the test to verify**
   ```bash
   CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/Worktree/nano-vllm:$PYTHONPATH \
       python tests/test_xattn_estimate_chunked.py
   ```

### Contents of the nanovllm/ops Module

| File | Core function | Purpose |
|------|----------|------|
| `xattn.py` | `xattn_estimate()` | standard XAttention estimation |
| `xattn.py` | `xattn_estimate_chunked()` | chunked prefill variant |
| `xattn.py` | `flat_group_gemm_fuse_reshape()` | Triton kernel: fused reshape + GEMM |
| `xattn.py` | `softmax_fuse_block_sum()` | Triton kernel: softmax + block sum |
| `xattn.py` | `find_blocks_chunked()` | block selection based on a threshold |
| `chunked_attention.py` | `flash_attn_with_lse()` | flash attention with LSE output |
| `chunked_attention.py` | `merge_attention_outputs()` | merge multiple attention chunks |

### Relationship to the Policy

```
XAttentionPolicy.estimate()
└── calls nanovllm.ops.xattn.xattn_estimate()
    ├── flat_group_gemm_fuse_reshape()  (Triton)
    ├── softmax_fuse_block_sum()        (Triton)
    └── find_blocks_chunked()
```

---

## Key Questions

1. **What does `select_blocks` become?**
   - Rename it to `estimate()`: it computes the sparse mask
   - For XAttention this corresponds to COMPASS's `xattn_estimate()` function
   - FullAttentionPolicy's `estimate()` returns None (meaning full attention)

2. **How should the policy interface be designed?**
   - Prefill: `compute_prefill(q, k, v, layer_id, softmax_scale)`
   - Decode: `compute_decode(q, k, v, layer_id, softmax_scale)`
   - Estimate: `estimate(q, k, layer_id)` - computes the sparse mask

3. **How is the FULL policy handled?**
   - FULL also implements `compute_prefill/decode`, using FlashAttention
   - Its `estimate()` returns None (meaning no sparsification)

## Proposed New Interface

```python
from abc import ABC, abstractmethod
from typing import Optional
import torch


class AttentionPolicy(ABC):
    """Attention policy for layerwise offload mode.

    All attention computation goes through a policy, including Full and Sparse.
    Both the prefill and decode phases are supported.
    """

    supports_prefill: bool = True
    supports_decode: bool = True

    def estimate(
        self,
        q: torch.Tensor,  # [seq_len, num_heads, head_dim]
        k: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate the sparse attention mask.

        For sparse policies (such as XAttention), computes which blocks to
        attend to. For the full policy, returns None to indicate full attention.

        Corresponds to COMPASS's xattn_estimate() function.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask, or None
        """
        return None  # default: full attention

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,  # [seq_len, num_heads, head_dim]
        k: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        v: torch.Tensor,  # [seq_len, num_kv_heads, head_dim]
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute prefill attention.

        The whole layer's KV is on GPU, so attention is computed in one pass.
        A policy may call estimate() first to obtain a sparse mask and then
        apply block sparse attention.

        Args:
            q: Query tensor [seq_len, num_heads, head_dim]
            k: Key tensor [seq_len, num_kv_heads, head_dim]
            v: Value tensor [seq_len, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor (1/sqrt(head_dim))

        Returns:
            Attention output [seq_len, num_heads, head_dim]
        """
        pass

    def compute_decode(
        self,
        q: torch.Tensor,  # [1, num_heads, head_dim]
        k: torch.Tensor,  # [context_len+1, num_kv_heads, head_dim]
        v: torch.Tensor,  # [context_len+1, num_kv_heads, head_dim]
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute decode attention.

        KV comes from the ring buffer and contains the prefill tokens plus the
        tokens decoded so far.

        Args:
            q: Query tensor [1, num_heads, head_dim]
            k: Key tensor [context_len+1, num_kv_heads, head_dim]
            v: Value tensor [context_len+1, num_kv_heads, head_dim]
            layer_id: Transformer layer index
            softmax_scale: Softmax scaling factor

        Returns:
            Attention output [1, num_heads, head_dim]
        """
        # Default implementation: FlashAttention
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        context_len = k.shape[0]
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=context_len,
            softmax_scale=softmax_scale,
            causal=False,
        )

    def reset(self) -> None:
        """Reset policy state between sequences."""
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


# Keep the old name as an alias
SparsePolicy = AttentionPolicy
```

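Phase 5 of this plan calls a `create_attention_policy` factory that the plan never defines. A minimal, hypothetical sketch of what it might look like (the registry keys "full"/"xattn" and the stub classes here are assumptions; the real classes would live under `nanovllm/kvcache/sparse/`):

```python
# Hypothetical factory sketch: maps a config string to a policy instance.
# The registry keys ("full", "xattn") are assumptions, not part of the plan.

class AttentionPolicy:
    """Stand-in base class (the real one lives in nanovllm/kvcache/sparse/policy.py)."""
    def __repr__(self):
        return f"{self.__class__.__name__}()"

class FullAttentionPolicy(AttentionPolicy):
    pass

class XAttentionPolicy(AttentionPolicy):
    def __init__(self, stride=8, threshold=0.9, block_size=128):
        self.stride = stride
        self.threshold = threshold
        self.block_size = block_size

_POLICIES = {
    "full": FullAttentionPolicy,
    "xattn": XAttentionPolicy,
}

def create_attention_policy(name: str, **kwargs) -> AttentionPolicy:
    """Instantiate a policy by name; unknown names fail early with the valid options."""
    try:
        cls = _POLICIES[name.lower()]
    except KeyError:
        raise ValueError(
            f"Unknown attention policy {name!r}; expected one of {sorted(_POLICIES)}"
        )
    return cls(**kwargs)

print(create_attention_policy("full"))  # prints FullAttentionPolicy()
```

A registry-based factory like this keeps the FULL policy on the same code path as sparse policies, matching the plan's "all policies are called uniformly" recommendation.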
## Implementation Plan

### Phase 2: Refactor policy.py

```python
# nanovllm/kvcache/sparse/policy.py

from abc import ABC, abstractmethod
from typing import Optional
import torch


class AttentionPolicy(ABC):
    """Base class for attention policies in layerwise offload mode."""

    supports_prefill: bool = True
    supports_decode: bool = True

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        Estimate sparse attention mask.

        For sparse policies (e.g., XAttention), computes block-level importance.
        For full policy, returns None.

        Corresponds to xattn_estimate() in COMPASS.

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] or None
        """
        return None

    @abstractmethod
    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute prefill attention."""
        pass

    def compute_decode(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """Compute decode attention (default: FlashAttention)."""
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        context_len = k.shape[0]
        cu_seqlens_q = torch.tensor([0, 1], dtype=torch.int32, device=q.device)
        cu_seqlens_k = torch.tensor([0, context_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=1,
            max_seqlen_k=context_len,
            softmax_scale=softmax_scale,
            causal=False,
        )

    def reset(self) -> None:
        pass

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


# Backward compatibility alias
SparsePolicy = AttentionPolicy
```

### Phase 3: Refactor FullAttentionPolicy

```python
# nanovllm/kvcache/sparse/full_policy.py

import torch
from .policy import AttentionPolicy


class FullAttentionPolicy(AttentionPolicy):
    """Full attention using FlashAttention (no sparsity)."""

    supports_prefill = True
    supports_decode = True

    def estimate(self, q, k, layer_id):
        """Full attention - no sparse mask needed."""
        return None

    def compute_prefill(self, q, k, v, layer_id, softmax_scale):
        from flash_attn.flash_attn_interface import flash_attn_varlen_func

        seq_len = q.shape[0]
        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

        return flash_attn_varlen_func(
            q, k, v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            softmax_scale=softmax_scale,
            causal=True,
        )

    def __repr__(self):
        return "FullAttentionPolicy()"
```

### Phase 4: 重构 XAttentionPolicy
|
|
||||||
|
|
||||||
```python
|
|
||||||
# nanovllm/kvcache/sparse/xattn.py
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from typing import Optional
|
|
||||||
from .policy import AttentionPolicy
|
|
||||||
|
|
||||||
|
|
||||||
class XAttentionPolicy(AttentionPolicy):
|
|
||||||
"""
|
|
||||||
XAttention sparse prefill policy.
|
|
||||||
|
|
||||||
Uses chunked estimation to compute sparse attention mask,
|
|
||||||
then applies block sparse attention.
|
|
||||||
"""
|
|
||||||
|
|
    supports_prefill = True
    supports_decode = True

    def __init__(
        self,
        stride: int = 8,
        threshold: float = 0.9,
        block_size: int = 128,
        chunk_size: int = 16384,
        use_triton: bool = True,
    ):
        self.stride = stride
        self.threshold = threshold
        self.block_size = block_size
        self.chunk_size = chunk_size
        self.use_triton = use_triton

    def estimate(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        layer_id: int,
    ) -> Optional[torch.Tensor]:
        """
        XAttention estimation (xattn_estimate).

        Uses chunked GEMM + softmax to estimate block-level importance,
        then selects important blocks based on a threshold.

        Corresponds to COMPASS's xattn_estimate() function:
        1. Pad inputs to chunk_size multiples
        2. Reshape with stride
        3. Compute QK^T in chunks (Triton)
        4. Block-wise softmax + aggregation
        5. Threshold-based selection

        Args:
            q: [seq_len, num_heads, head_dim]
            k: [seq_len, num_kv_heads, head_dim]
            layer_id: transformer layer index

        Returns:
            sparse_mask: [num_heads, q_blocks, k_blocks] boolean mask,
            or None (fall back to full attention)
        """
        # TODO: implement the real xattn_estimate.
        # For now, return None so callers fall back to full attention.
        return None

    def compute_prefill(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        layer_id: int,
        softmax_scale: float,
    ) -> torch.Tensor:
        """
        Compute XAttention sparse prefill.

        Flow:
        1. Call estimate() to get the sparse mask
        2. If the mask is None, use full attention
        3. Otherwise, apply block sparse attention with the mask
        """
        # Step 1: Estimate sparse mask
        sparse_mask = self.estimate(q, k, layer_id)

        # Step 2: Compute attention
        if sparse_mask is None:
            # Fallback to full attention
            from flash_attn.flash_attn_interface import flash_attn_varlen_func

            seq_len = q.shape[0]
            cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=q.device)

            return flash_attn_varlen_func(
                q, k, v,
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=seq_len,
                max_seqlen_k=seq_len,
                softmax_scale=softmax_scale,
                causal=True,
            )
        else:
            # Apply block sparse attention with the mask, e.g.
            # block_sparse_attn_func(q, k, v, sparse_mask, block_size)
            raise NotImplementedError("Block sparse attention not yet implemented")

    def __repr__(self):
        return (f"XAttentionPolicy("
                f"stride={self.stride}, "
                f"threshold={self.threshold}, "
                f"block_size={self.block_size})")
```
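Step 5 (threshold-based selection) is the part the TODO above leaves open. A minimal, torch-free sketch of that selection rule, operating on hypothetical pre-aggregated block scores for one query block (illustrative only, not the repo's actual implementation):

```python
import math

def select_blocks(block_scores, threshold=0.9):
    """block_scores: per-key-block attention scores for one query block.
    Returns a boolean keep-mask: the fewest highest-scoring blocks whose
    cumulative softmax mass reaches `threshold`."""
    m = max(block_scores)
    exps = [math.exp(s - m) for s in block_scores]
    z = sum(exps)
    probs = [e / z for e in exps]
    # Visit blocks in descending probability order
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    keep = [False] * len(probs)
    mass = 0.0
    for i in order:
        keep[i] = True           # keep this block...
        mass += probs[i]
        if mass >= threshold:    # ...and stop once enough mass is covered
            break
    return keep

print(select_blocks([10.0, 1.0, 0.5, 0.1]))  # [True, False, False, False]
```

With a dominant block, one block already covers 90% of the softmax mass; with uniform scores, all blocks are kept — which is why a higher threshold means more blocks.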

### Phase 5: Update model_runner.py

```python
# model_runner.py - allocate_kv_cache()

# Always create a policy (including FULL)
from nanovllm.kvcache.sparse import create_attention_policy
self.attention_policy = create_attention_policy(config.attention_policy, **policy_kwargs)
logger.info(f"Attention policy: {self.attention_policy}")

# run_layerwise_offload_prefill() and run_gpu_only_prefill()

# Old code:
if self.sparse_prefill_policy is not None:
    attn_output = self.sparse_prefill_policy.sparse_prefill_attention(q, k, v, layer_id)
else:
    attn_output = flash_attn_varlen_func(...)

# New code:
attn_output = self.attention_policy.compute_prefill(
    q, k, v, layer_id, softmax_scale=layer.self_attn.attn.scale
)
```

## Method Mapping

| Old method | New method | Notes |
|--------|--------|------|
| `select_blocks()` | `estimate()` | Computes the sparse mask (maps to xattn_estimate) |
| `sparse_prefill_attention()` | `compute_prefill()` | Prefill attention |
| (none) | `compute_decode()` | Decode attention (default implementation) |
| `on_prefill_offload()` | (removed) | Offload is handled in model_runner |

## Files to Modify

| File | Changes |
|------|---------|
| `nanovllm/kvcache/sparse/policy.py` | New interface: estimate, compute_prefill, compute_decode |
| `nanovllm/kvcache/sparse/full_policy.py` | Implement compute_prefill(); estimate() returns None |
| `nanovllm/kvcache/sparse/xattn.py` | estimate() maps to xattn_estimate; compute_prefill() |
| `nanovllm/kvcache/sparse/__init__.py` | Update the factory function |
| `nanovllm/engine/model_runner.py` | Call attention_policy.compute_prefill() uniformly |
| `nanovllm/config.py` | Optional: rename config keys |

## Decisions Made

1. **Method naming**: `compute_prefill` / `compute_decode` follow the naming style of the chunked implementation
2. **estimate method**: replaces `select_blocks`; returns a sparse mask rather than block IDs
3. **XAttention**: `estimate()` corresponds to COMPASS's `xattn_estimate()`
4. **Full policy**: `estimate()` returns None, meaning full attention is used
5. **Default decode**: the base class provides a default FlashAttention implementation

## Errors Encountered

- (none)

## Status

**Currently in Phase 1** - analysis and interface design complete; awaiting user confirmation before entering Phase 2
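The factory update for `__init__.py` is listed above but not shown. A minimal, torch-free sketch of one way to wire a policy factory so that every mode, FULL included, yields a policy object (`PolicyType`, `_REGISTRY`, and `register_policy` are illustrative names, not the repo's actual API):

```python
from enum import Enum
from typing import Any, Callable, Dict, Type

class PolicyType(str, Enum):
    FULL = "full"
    QUEST = "quest"
    MINFERENCE = "minference"
    XATTN = "xattn"

_REGISTRY: Dict[PolicyType, Type] = {}

def register_policy(ptype: PolicyType) -> Callable[[Type], Type]:
    """Class decorator: associate a policy class with its enum value."""
    def deco(cls: Type) -> Type:
        _REGISTRY[ptype] = cls
        return cls
    return deco

@register_policy(PolicyType.FULL)
class FullPolicy:
    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs
    def estimate(self, *args: Any):
        return None  # None => callers fall back to full attention

def create_attention_policy(ptype: PolicyType, **kwargs: Any):
    """Always returns a policy object, so model_runner never branches on None."""
    return _REGISTRY[ptype](**kwargs)

policy = create_attention_policy(PolicyType.FULL)
print(type(policy).__name__)  # FullPolicy
```

The registry keeps `model_runner.py` free of per-policy imports: adding a policy means registering one class, not touching the call sites.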
Changes to the needle-test driver (the `--enable-xattn` path is removed):

```diff
@@ -32,14 +32,11 @@ def run_needle_test(
     enable_cpu_offload: bool = False,
     enable_quest: bool = False,
     enable_minference: bool = False,
-    enable_xattn: bool = False,
     sparse_topk: int = 8,
     sparse_threshold: int = 4,
     minference_budget: float = 0.3,
     minference_vertical: int = 1000,
     minference_slash: int = 6096,
-    xattn_threshold: float = 0.9,
-    xattn_use_bsa: bool = True,
     gpu_utilization: float = 0.9,
     enforce_eager: bool = True,
     verbose: bool = True,
@@ -59,14 +56,11 @@ def run_needle_test(
         enable_cpu_offload: Enable CPU offload mode
         enable_quest: Enable Quest sparse attention (decode-only Top-K)
         enable_minference: Enable MInference sparse prefill (GPU-only)
-        enable_xattn: Enable XAttention sparse prefill with BSA
         sparse_topk: Top-K blocks for Quest
         sparse_threshold: Apply sparse only when blocks > threshold
         minference_budget: MInference adaptive budget (fraction of seq_len, None=fixed mode)
         minference_vertical: Fixed vertical_size (only used when budget=None)
         minference_slash: Fixed slash_size (only used when budget=None)
-        xattn_threshold: XAttention block selection threshold (0-1)
-        xattn_use_bsa: Use Block Sparse Attention library
         gpu_utilization: GPU memory utilization fraction
         verbose: Print detailed output

@@ -74,9 +68,7 @@ def run_needle_test(
         True if test passed, False otherwise
     """
     # Determine sparse policy
-    if enable_xattn:
-        sparse_policy = SparsePolicyType.XATTN
-    elif enable_minference:
+    if enable_minference:
         sparse_policy = SparsePolicyType.MINFERENCE
     elif enable_quest:
         sparse_policy = SparsePolicyType.QUEST
@@ -102,8 +94,6 @@ def run_needle_test(
             print(f"  MInference: adaptive (budget={minference_budget})")
         else:
             print(f"  MInference: fixed (vertical={minference_vertical}, slash={minference_slash})")
-        if enable_xattn:
-            print(f"  XAttention: threshold={xattn_threshold}, use_bsa={xattn_use_bsa}")
         print(f"{'='*60}\n")

     # 1. Initialize LLM
@@ -121,7 +111,7 @@ def run_needle_test(
         llm_kwargs["sparse_threshold_blocks"] = sparse_threshold

     # Set sparse policy (can be used with or without offload)
-    if enable_minference or enable_quest or enable_xattn:
+    if enable_minference or enable_quest:
         llm_kwargs["sparse_policy"] = sparse_policy

     # MInference params (works with both GPU-only and offload mode)
@@ -130,11 +120,6 @@ def run_needle_test(
         llm_kwargs["minference_vertical_size"] = minference_vertical
         llm_kwargs["minference_slash_size"] = minference_slash

-    # XAttention params
-    if enable_xattn:
-        llm_kwargs["xattn_threshold"] = xattn_threshold
-        llm_kwargs["xattn_use_bsa"] = xattn_use_bsa
-
     llm = LLM(model_path, **llm_kwargs)

     # 2. Generate needle prompt
@@ -239,11 +224,6 @@ if __name__ == "__main__":
         action="store_true",
         help="Enable MInference sparse prefill (GPU-only, vertical+slash pattern)"
     )
-    parser.add_argument(
-        "--enable-xattn",
-        action="store_true",
-        help="Enable XAttention sparse prefill with Block Sparse Attention"
-    )
     parser.add_argument(
         "--sparse-topk",
         type=int,
@@ -274,17 +254,6 @@ if __name__ == "__main__":
         default=6096,
         help="Fixed slash_size (only used when budget=0)"
     )
-    parser.add_argument(
-        "--xattn-threshold",
-        type=float,
-        default=0.9,
-        help="XAttention block selection threshold (0-1, higher=more blocks)"
-    )
-    parser.add_argument(
-        "--xattn-no-bsa",
-        action="store_true",
-        help="Disable Block Sparse Attention (use FlashAttention fallback)"
-    )
     parser.add_argument(
         "--gpu-utilization",
         type=float,
@@ -322,14 +291,11 @@ if __name__ == "__main__":
         enable_cpu_offload=args.enable_offload,
         enable_quest=args.enable_quest,
         enable_minference=args.enable_minference,
-        enable_xattn=args.enable_xattn,
         sparse_topk=args.sparse_topk,
         sparse_threshold=args.sparse_threshold,
         minference_budget=minference_budget,
         minference_vertical=args.minference_vertical,
         minference_slash=args.minference_slash,
-        xattn_threshold=args.xattn_threshold,
-        xattn_use_bsa=not args.xattn_no_bsa,
         gpu_utilization=args.gpu_utilization,
         enforce_eager=enforce_eager,
         verbose=True,
```

## tests/test_offload_unified.py (new file, 841 lines)

```python
"""
OffloadedTensor unified test suite

This file consolidates all OffloadedTensor tests, covering:
1. Basic functionality checks
2. Chunked GEMM tests
3. Synchronization analysis

Core components:
- OffloadedTensor: virtual GPU tensor with transparent CPU/GPU data movement
- OffloadManager: LRU cache management with sync/async transfers
- ChunkedOffloadLinear: a Linear layer chunked along the seqlen dimension
"""

import torch
import torch.nn as nn
import weakref
import threading
import time
from typing import Optional, Dict, List, Tuple, Any
from dataclasses import dataclass


# ============================================================
# Part 1: Core components
# ============================================================

class OffloadedTensor(torch.Tensor):
    """
    Virtual GPU tensor: reports itself as living on the GPU while the
    data may actually be on the CPU.

    All compute ops are intercepted via __torch_dispatch__, which loads
    the data onto the GPU before the computation runs.
    """

    @staticmethod
    def __new__(cls, real_tensor: torch.Tensor, manager: 'OffloadManager', tensor_id: int):
        device = torch.device("cuda", torch.cuda.current_device())
        ret = torch.Tensor._make_wrapper_subclass(
            cls,
            real_tensor.size(),
            strides=real_tensor.stride(),
            dtype=real_tensor.dtype,
            device=device,
            requires_grad=real_tensor.requires_grad
        )
        ret._real_tensor = real_tensor
        ret._manager = weakref.ref(manager)
        ret._tensor_id = tensor_id
        return ret

    def __init__(self, real_tensor: torch.Tensor, manager: 'OffloadManager', tensor_id: int):
        self._real_tensor = real_tensor
        self._manager = weakref.ref(manager)
        self._tensor_id = tensor_id

    @property
    def device(self) -> torch.device:
        """Always report a CUDA device to satisfy PyTorch's device checks."""
        return torch.device("cuda", torch.cuda.current_device())

    def to(self, *args, **kwargs):
        """Intercept .to() calls."""
        device = None
        if args and isinstance(args[0], (torch.device, str)):
            device = args[0]
        elif 'device' in kwargs:
            device = kwargs['device']
        if device is not None:
            device = torch.device(device)  # normalize str -> torch.device

        if device is not None and device.type == "cuda":
            return self
        return super().to(*args, **kwargs)

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        """Intercept all PyTorch ops and load data on demand."""
        kwargs = kwargs or {}

        manager = self._manager()
        if manager:
            manager.stats['dispatch_count'] += 1

        # Special case: detach returns self
        func_name = getattr(func, '__name__', '')
        if isinstance(func_name, str) and 'detach' in func_name.lower():
            return self

        # Unwrap OffloadedTensor into the real (GPU-resident) tensor
        def unwrap(t):
            if isinstance(t, OffloadedTensor):
                mgr = t._manager()
                if mgr:
                    return mgr.get_gpu_tensor(t._real_tensor, t._tensor_id)
                return t._real_tensor.cuda()
            return t

        new_args = torch.utils._pytree.tree_map(unwrap, args)
        new_kwargs = torch.utils._pytree.tree_map(unwrap, kwargs)

        result = func(*new_args, **new_kwargs)
        return result


class OffloadManager:
    """
    Manages tensor offloading and prefetching.

    Features:
    - LRU cache for tensors resident on the GPU
    - Sync/async transfer modes
    - Full performance statistics
    """

    def __init__(
        self,
        device: str = "cuda",
        offload_device: str = "cpu",
        max_gpu_tensors: int = 2,
        non_blocking: bool = False,
    ):
        self.device = torch.device(device)
        self.offload_device = torch.device(offload_device)
        self._gpu_pool: Dict[int, torch.Tensor] = {}
        self._cpu_storage: Dict[int, torch.Tensor] = {}
        self._lock = threading.Lock()
        self._tensor_id_counter = 0
        self._max_gpu_tensors = max_gpu_tensors
        self._access_order: List[int] = []
        self.non_blocking = non_blocking

        # Statistics
        self.stats = {
            'load_count': 0,
            'evict_count': 0,
            'dispatch_count': 0,
            'transfer_times_ms': [],
        }

    def _next_id(self) -> int:
        tid = self._tensor_id_counter
        self._tensor_id_counter += 1
        return tid

    def wrap(self, tensor: torch.Tensor) -> OffloadedTensor:
        """Wrap a tensor as a virtual GPU tensor."""
        if isinstance(tensor, OffloadedTensor):
            return tensor

        tensor_id = self._next_id()
        cpu_tensor = tensor.detach().to(self.offload_device)
        self._cpu_storage[tensor_id] = cpu_tensor

        return OffloadedTensor(cpu_tensor, self, tensor_id)

    def get_gpu_tensor(self, real_tensor: torch.Tensor, tensor_id: int) -> torch.Tensor:
        """Fetch the data on the GPU (LRU cached)."""
        with self._lock:
            self.stats['load_count'] += 1

            if tensor_id in self._gpu_pool:
                # Already on GPU; refresh LRU position
                if tensor_id in self._access_order:
                    self._access_order.remove(tensor_id)
                self._access_order.append(tensor_id)
                return self._gpu_pool[tensor_id]

            # LRU eviction
            while len(self._gpu_pool) >= self._max_gpu_tensors:
                if self._access_order:
                    evict_id = self._access_order.pop(0)
                    if evict_id in self._gpu_pool:
                        del self._gpu_pool[evict_id]
                        self.stats['evict_count'] += 1
                else:
                    break

            # Load onto the GPU
            cpu_tensor = self._cpu_storage.get(tensor_id, real_tensor)
            gpu_tensor = cpu_tensor.to(self.device, non_blocking=self.non_blocking)
            self._gpu_pool[tensor_id] = gpu_tensor
            self._access_order.append(tensor_id)

            return gpu_tensor

    def get_stats(self) -> Dict[str, Any]:
        """Return statistics."""
        transfer_times = self.stats['transfer_times_ms']
        return {
            'load_count': self.stats['load_count'],
            'evict_count': self.stats['evict_count'],
            'dispatch_count': self.stats['dispatch_count'],
            'gpu_pool_size': len(self._gpu_pool),
            'total_tensors': len(self._cpu_storage),
            'total_transfer_time_ms': sum(transfer_times),
            'avg_transfer_time_ms': sum(transfer_times) / len(transfer_times) if transfer_times else 0,
            'transfer_times_ms': list(transfer_times),
        }


class OffloadModuleWrapper(nn.Module):
    """Wrap an nn.Module to offload at parameter granularity."""

    def __init__(self, module: nn.Module, manager: OffloadManager):
        super().__init__()
        self._original_module = module
        self._manager = manager
        self._wrap_parameters(module, "")

    def _wrap_parameters(self, module: nn.Module, prefix: str):
        """Recursively wrap all parameters of a module."""
        for name, param in list(module.named_parameters(recurse=False)):
            param.requires_grad_(False)
            wrapped = self._manager.wrap(param.data)
            delattr(module, name)
            setattr(module, name, wrapped)

        for child_name, child in list(module.named_children()):
            self._wrap_parameters(child, prefix + child_name + ".")

    def forward(self, *args, **kwargs):
        return self._original_module(*args, **kwargs)
```
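The LRU bookkeeping in `get_gpu_tensor` can be exercised without CUDA. A torch-free sketch of the same evict-on-capacity policy, using `OrderedDict` instead of a dict plus access list (illustrative only):

```python
from collections import OrderedDict

class LRUPool:
    def __init__(self, capacity: int = 2):
        self.capacity = capacity
        self.pool: "OrderedDict[int, object]" = OrderedDict()
        self.evictions = 0

    def get(self, key: int, load):
        if key in self.pool:
            self.pool.move_to_end(key)      # cache hit: refresh LRU position
            return self.pool[key]
        while len(self.pool) >= self.capacity:
            self.pool.popitem(last=False)   # evict the least recently used entry
            self.evictions += 1
        value = load()                      # stands in for the CPU -> GPU transfer
        self.pool[key] = value
        return value

pool = LRUPool(capacity=2)
for k in (0, 1, 0, 2):                      # re-accessing 0 makes 1 the LRU victim
    pool.get(k, lambda: object())
print(sorted(pool.pool), pool.evictions)    # [0, 2] 1
```

`OrderedDict.move_to_end` / `popitem(last=False)` give the same semantics as the `_access_order` list above, in O(1) per access instead of O(n) for `list.remove`.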

```python
# ============================================================
# Part 2: Higher-level modules
# ============================================================

class ChunkedOffloadLinear(nn.Module):
    """
    A Linear layer chunked along the seqlen dimension.

    Splits the input [seqlen, in_features] into chunks; each chunk runs its
    GEMM independently. The weight is an OffloadedTensor, loaded onto the
    GPU on demand.

    Args:
        in_features: input feature dimension
        out_features: output feature dimension
        chunk_size: size of each chunk
        max_gpu_tensors: maximum number of tensors cached on the GPU
        non_blocking: whether to use async transfers
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        chunk_size: int = 4096,
        max_gpu_tensors: int = 2,
        non_blocking: bool = False,
        bias: bool = False,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.chunk_size = chunk_size

        self.manager = OffloadManager(
            max_gpu_tensors=max_gpu_tensors,
            non_blocking=non_blocking
        )

        weight_tensor = torch.empty(out_features, in_features, dtype=torch.float16)
        nn.init.xavier_uniform_(weight_tensor)
        weight_tensor.requires_grad_(False)

        self.weight = self.manager.wrap(weight_tensor)
        self.bias = None
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seqlen = x.shape[0]

        if seqlen <= self.chunk_size:
            return self._compute_chunk(x)

        outputs = []
        for start_idx in range(0, seqlen, self.chunk_size):
            end_idx = min(start_idx + self.chunk_size, seqlen)
            chunk = x[start_idx:end_idx]
            chunk_output = self._compute_chunk(chunk)
            outputs.append(chunk_output)

        return torch.cat(outputs, dim=0)

    def _compute_chunk(self, chunk: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.linear(chunk, self.weight, self.bias)
```

```python
# ============================================================
# Helper functions
# ============================================================

def calculate_memory(
    seqlen: int,
    in_features: int,
    out_features: int,
    dtype: torch.dtype = torch.float16,
) -> Dict[str, float]:
    """Compute the memory footprint (MB)."""
    element_size = torch.finfo(dtype).bits / 8

    activation = seqlen * in_features * element_size / (1024 ** 2)
    weight = in_features * out_features * element_size / (1024 ** 2)
    output = seqlen * out_features * element_size / (1024 ** 2)

    total = activation + weight + output
    peak = max(activation, output) + weight

    return {
        'activation_mb': activation,
        'weight_mb': weight,
        'output_mb': output,
        'total_mb': total,
        'peak_mb': peak,
    }


def run_benchmark(
    layer: nn.Module,
    input_tensor: torch.Tensor,
    num_runs: int = 3,
) -> Dict[str, float]:
    """Run a timing benchmark."""
    torch.cuda.synchronize()

    # Warmup
    with torch.no_grad():
        _ = layer(input_tensor)
    torch.cuda.synchronize()

    # Benchmark
    start_time = time.time()
    for _ in range(num_runs):
        with torch.no_grad():
            output = layer(input_tensor)
    torch.cuda.synchronize()

    elapsed = time.time() - start_time
    avg_time = elapsed / num_runs

    total_elements = input_tensor.numel() + output.numel()
    throughput = total_elements / avg_time / 1e6

    return {
        'avg_time_ms': avg_time * 1000,
        'throughput_meps': throughput,
    }
```
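A quick, torch-free sanity check of `calculate_memory`'s arithmetic, using the dimensions the tests below use. The point of the chunked layout is visible directly: only one chunk's activation and output are live at a time, so the chunked peak is independent of the full sequence length.

```python
# fp16 = 2 bytes per element; dimensions from the tests: in=4096, out=12244, chunk=4096
bytes_per_elem = 2
mib = 1024 ** 2
in_features, out_features, chunk = 4096, 12244, 4096

weight_mb = in_features * out_features * bytes_per_elem / mib  # 95.65625 MB, resident regardless of seqlen
act_mb = chunk * in_features * bytes_per_elem / mib            # 32.0 MB per 4096-token chunk
out_mb = chunk * out_features * bytes_per_elem / mib           # 95.65625 MB per chunk
peak_chunked = max(act_mb, out_mb) + weight_mb                 # 191.3125 MB, constant in seqlen

print(weight_mb, act_mb, out_mb, peak_chunked)
```

For comparison, at seqlen = 131072 the unchunked activation alone would be 32 MB x 32 = 1024 MB, which is the gap Test 5 prints.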
|
# ============================================================
|
||||||
|
# Part 3: 测试套件 - 功能测试
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def test_1_basic_offloaded_tensor():
|
||||||
|
"""测试 OffloadedTensor 基本功能"""
|
||||||
|
print("\n=== Test 1: Basic OffloadedTensor ===")
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("CUDA not available, skipping")
|
||||||
|
return
|
||||||
|
|
||||||
|
manager = OffloadManager(max_gpu_tensors=2)
|
||||||
|
|
||||||
|
t1 = torch.randn(4, 4)
|
||||||
|
t2 = torch.randn(4, 4)
|
||||||
|
t3 = torch.randn(4, 4)
|
||||||
|
|
||||||
|
w1 = manager.wrap(t1)
|
||||||
|
w2 = manager.wrap(t2)
|
||||||
|
w3 = manager.wrap(t3)
|
||||||
|
|
||||||
|
print(f"✓ Created OffloadedTensors")
|
||||||
|
print(f" w1.device: {w1.device}")
|
||||||
|
print(f" w2.device: {w2.device}")
|
||||||
|
|
||||||
|
assert w1.device.type == "cuda"
|
||||||
|
print(f"✓ is_cuda check passed")
|
||||||
|
|
||||||
|
result = w1 + w2
|
||||||
|
print(f"✓ Addition works: {result.shape}")
|
||||||
|
|
||||||
|
stats = manager.get_stats()
|
||||||
|
print(f"✓ Manager stats: {stats}")
|
||||||
|
print("PASSED\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_2_mlp_with_offload():
|
||||||
|
"""测试 MLP 模型使用 OffloadedTensor"""
|
||||||
|
print("\n=== Test 2: MLP with OffloadedTensor ===")
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("CUDA not available, skipping")
|
||||||
|
return
|
||||||
|
|
||||||
|
class SimpleMLP(nn.Module):
|
||||||
|
def __init__(self, hidden_size=128, intermediate_size=256):
|
||||||
|
super().__init__()
|
||||||
|
self.gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
|
||||||
|
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
|
||||||
|
return self.down_proj(nn.functional.silu(gate) * up)
|
||||||
|
|
||||||
|
hidden_size = 128
|
||||||
|
intermediate_size = 256
|
||||||
|
batch_size, seq_len = 2, 4
|
||||||
|
|
||||||
|
input_ids = torch.randn(batch_size, seq_len, hidden_size, device="cuda")
|
||||||
|
|
||||||
|
model_original = SimpleMLP(hidden_size, intermediate_size)
|
||||||
|
model_original.to("cuda")
|
||||||
|
model_original.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
expected = model_original(input_ids)
|
||||||
|
|
||||||
|
state_dict = model_original.state_dict()
|
||||||
|
|
||||||
|
model = SimpleMLP(hidden_size, intermediate_size)
|
||||||
|
model.load_state_dict(state_dict)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
offloaded_model, manager = apply_offload_to_model(model, max_gpu_tensors=2)
|
||||||
|
offloaded_model.eval()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
output = offloaded_model(input_ids)
|
||||||
|
|
||||||
|
print(f"✓ Forward pass completed: {output.shape}")
|
||||||
|
|
||||||
|
stats = manager.get_stats()
|
||||||
|
print(f"✓ Offload stats: {stats}")
|
||||||
|
|
||||||
|
diff = (output - expected).abs().max().item()
|
||||||
|
print(f"✓ Output correctness: max diff = {diff:.6f}")
|
||||||
|
|
||||||
|
assert diff < 1e-5
|
||||||
|
print("PASSED\n")
|
||||||
|
|
||||||
|
|
||||||
|
def apply_offload_to_model(model: nn.Module, max_gpu_tensors: int = 2):
|
||||||
|
"""应用卸载到模型的所有参数"""
|
||||||
|
manager = OffloadManager(max_gpu_tensors=max_gpu_tensors)
|
||||||
|
wrapper = OffloadModuleWrapper(model, manager)
|
||||||
|
return wrapper, manager
|
||||||
|
|
||||||
|
|
||||||
|
def test_3_lru_eviction():
|
||||||
|
"""测试 LRU 驱逐机制"""
|
||||||
|
print("\n=== Test 3: LRU Eviction ===")
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("CUDA not available, skipping")
|
||||||
|
return
|
||||||
|
|
||||||
|
manager = OffloadManager(max_gpu_tensors=2)
|
||||||
|
|
||||||
|
tensors = [torch.randn(2, 2) for _ in range(4)]
|
||||||
|
wrapped = [manager.wrap(t) for t in tensors]
|
||||||
|
|
||||||
|
print(f"✓ Created {len(wrapped)} OffloadedTensors")
|
||||||
|
print(f" GPU pool capacity: {manager._max_gpu_tensors}")
|
||||||
|
|
||||||
|
_ = wrapped[0] + wrapped[1]
|
||||||
|
stats = manager.get_stats()
|
||||||
|
print(f"✓ After accessing t1, t2: GPU pool = {stats['gpu_pool_size']}")
|
||||||
|
|
||||||
|
_ = wrapped[2] + wrapped[2]
|
||||||
|
stats = manager.get_stats()
|
||||||
|
print(f"✓ After accessing t3: GPU pool = {stats['gpu_pool_size']}, evicted = {stats['evict_count']}")
|
||||||
|
|
||||||
|
_ = wrapped[3] + wrapped[3]
|
||||||
|
stats = manager.get_stats()
|
||||||
|
print(f"✓ After accessing t4: GPU pool = {stats['gpu_pool_size']}, evicted = {stats['evict_count']}")
|
||||||
|
|
||||||
|
assert stats['evict_count'] >= 1
|
||||||
|
print("PASSED\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_4_correctness():
|
||||||
|
"""测试输出正确性"""
|
||||||
|
print("\n=== Test 4: Correctness Check ===")
|
||||||
|
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
print("CUDA not available, skipping")
|
||||||
|
return
|
||||||
|
|
||||||
|
in_features = 512
|
||||||
|
out_features = 1024
|
||||||
|
seqlen = 4096
|
||||||
|
chunk_size = 1024
|
||||||
|
|
||||||
|
x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)
|
||||||
|
|
||||||
|
# 创建标准层并保存权重
|
||||||
|
linear = nn.Linear(in_features, out_features, bias=False)
|
||||||
|
linear.to("cuda", dtype=torch.float16)
|
||||||
|
linear.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
expected = linear(x)
|
||||||
|
|
||||||
|
print(f"✓ Got expected output")
|
||||||
|
|
||||||
|
# 创建 ChunkedOffloadLinear,使用相同的权重
|
||||||
|
chunked_layer = ChunkedOffloadLinear(in_features, out_features, chunk_size, max_gpu_tensors=2)
|
||||||
|
|
||||||
|
# 复制权重到 chunked_layer
|
||||||
|
with torch.no_grad():
|
||||||
|
weight_data = linear.weight.data.cpu()
|
||||||
|
chunked_layer.manager._cpu_storage[0] = weight_data
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
actual = chunked_layer(x)
|
||||||
|
|
||||||
|
print(f"✓ Got actual output")
|
||||||
|
|
||||||
|
diff = (actual - expected).abs().max().item()
|
||||||
|
print(f"✓ Max difference: {diff:.6f}")
|
||||||
|
|
||||||
|
assert diff < 1e-5
|
||||||
|
print("PASSED\n")
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Part 3: 测试套件 - 性能测试
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def test_5_memory_analysis():
|
||||||
|
"""分析内存占用"""
|
||||||
|
print("\n=== Test 5: Memory Analysis ===")
|
||||||
|
|
||||||
|
in_features = 4096
|
||||||
|
out_features = 12244
|
||||||
|
chunk_size = 4096
|
||||||
|
|
||||||
|
seqlens = [4096, 16384, 65536, 131072]
|
||||||
|
|
||||||
|
print(f"\nMemory Analysis (in={in_features}, out={out_features}, chunk={chunk_size}):")
|
||||||
|
print(f"{'Seqlen':>10} | {'Activation':>12} | {'Weight':>12} | {'Output':>12} | {'Peak':>12} | {'Chunked':>12}")
|
||||||
|
print("-" * 90)
|
||||||
|
|
||||||
|
for seqlen in seqlens:
|
||||||
|
full = calculate_memory(seqlen, in_features, out_features)
|
||||||
|
chunked = calculate_memory(chunk_size, in_features, out_features)
|
||||||
|
|
||||||
|
print(f"{seqlen:>10} | "
|
||||||
|
f"{full['activation_mb']:>10.1f}MB | "
|
||||||
|
f"{full['weight_mb']:>10.1f}MB | "
|
||||||
|
f"{full['output_mb']:>10.1f}MB | "
|
||||||
|
f"{full['peak_mb']:>10.1f}MB | "
|
||||||
|
f"{chunked['peak_mb']:>10.1f}MB")
|
||||||
|
|
||||||
|
print("\n✓ Chunked offload 显存占用恒定,与序列长度无关!")
|
||||||
|
print("PASSED\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_6_long_sequence():
    """Test an extra-long sequence."""
    print("\n=== Test 6: Long Sequence (128K tokens) ===")

    if not torch.cuda.is_available():
        print("CUDA not available, skipping")
        return

    in_features = 4096
    out_features = 12244
    seqlen = 128 * 1024
    chunk_size = 4096

    full = calculate_memory(seqlen, in_features, out_features)
    chunked = calculate_memory(chunk_size, in_features, out_features)

    print(f"Memory Comparison:")
    print(f"  Full:    {full['peak_mb']:.1f} MB")
    print(f"  Chunked: {chunked['peak_mb']:.1f} MB")
    print(f"  Savings: {(1 - chunked['peak_mb']/full['peak_mb'])*100:.1f}%")

    layer = ChunkedOffloadLinear(in_features, out_features, chunk_size, max_gpu_tensors=1)
    x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)

    with torch.no_grad():
        start = time.time()
        output = layer(x)
        torch.cuda.synchronize()
        elapsed = (time.time() - start) * 1000

    print(f"✓ Forward pass: {output.shape}")
    print(f"  Time: {elapsed:.1f} ms")
    print(f"  Throughput: {seqlen/elapsed/1e3:.1f}K tokens/sec")

    stats = layer.manager.get_stats()
    print(f"✓ Chunks processed: {seqlen // chunk_size}")
    print(f"✓ Load count: {stats['load_count']}")
    print("PASSED\n")

def test_7_performance_comparison():
    """Performance comparison test."""
    print("\n=== Test 7: Performance Comparison ===")

    if not torch.cuda.is_available():
        print("CUDA not available, skipping")
        return

    in_features = 4096
    out_features = 12244
    seqlen = 16384
    chunk_size = 4096

    x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)

    linear = nn.Linear(in_features, out_features, bias=False).cuda().half().eval()
    standard_stats = run_benchmark(linear, x, num_runs=5)
    print(f"✓ Standard Linear: {standard_stats['avg_time_ms']:.1f} ms")

    chunked_layer = ChunkedOffloadLinear(in_features, out_features, chunk_size, max_gpu_tensors=1)
    chunked_stats = run_benchmark(chunked_layer, x, num_runs=5)
    print(f"✓ ChunkedOffloadLinear: {chunked_stats['avg_time_ms']:.1f} ms")

    speedup = standard_stats['avg_time_ms'] / chunked_stats['avg_time_ms']
    print(f"✓ Speedup: {speedup:.2f}x")
    print("PASSED\n")

def test_8_transformers_layer():
    """Test against real transformers weights."""
    print("\n=== Test 8: Transformers Layer Test ===")

    try:
        from transformers import AutoModelForCausalLM
    except ImportError:
        print("transformers not installed, skipping")
        return

    if not torch.cuda.is_available():
        print("CUDA not available, skipping")
        return

    model_name = "Qwen/Qwen2.5-0.5B-Instruct"

    try:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        model.eval()
        model.to("cuda")
    except Exception as e:
        print(f"Failed to load model: {e}")
        return

    down_proj = model.model.layers[0].mlp.down_proj
    print(f"✓ Got layer: {down_proj.in_features} -> {down_proj.out_features}")

    batch_size, seq_len = 1, 4
    test_input = torch.randn(batch_size, seq_len, down_proj.in_features, device="cuda", dtype=torch.float16)

    with torch.no_grad():
        normal_output = down_proj(test_input)

    print(f"✓ Normal inference: {normal_output.shape}")

    import copy
    test_linear = nn.Linear(down_proj.in_features, down_proj.out_features, bias=False)
    test_linear.load_state_dict(copy.deepcopy(down_proj.state_dict()))
    test_linear.to("cuda", dtype=torch.float16)
    test_linear.eval()

    manager = OffloadManager(max_gpu_tensors=2)
    offloaded_layer = OffloadModuleWrapper(test_linear, manager)

    with torch.no_grad():
        offload_output = offloaded_layer(test_input)

    print(f"✓ Offload inference: {offload_output.shape}")

    stats = manager.get_stats()
    print(f"✓ Stats: {stats}")

    diff = (offload_output - normal_output).abs().max().item()
    print(f"✓ Max diff: {diff:.6f}")

    assert diff < 1e-5
    print("PASSED\n")

# ============================================================
# Part 3: Test Suite - Sync Analysis
# ============================================================

def test_9_sync_behavior_analysis():
    """Compare synchronous vs asynchronous transfers."""
    print("\n=== Test 9: Sync Behavior Analysis ===")

    if not torch.cuda.is_available():
        print("CUDA not available, skipping")
        return

    in_features = 4096
    out_features = 12244
    seqlen = 16384
    chunk_size = 4096

    print(f"Config: in={in_features}, out={out_features}, seqlen={seqlen}, chunk={chunk_size}")
    print(f"Num chunks: {seqlen // chunk_size}")

    x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)

    # Synchronous version
    print(f"\n--- Synchronous transfer (non_blocking=False) ---")
    layer_sync = ChunkedOffloadLinear(in_features, out_features, chunk_size, non_blocking=False)

    with torch.no_grad():
        start = time.time()
        _ = layer_sync(x)
        torch.cuda.synchronize()
        sync_time_ms = (time.time() - start) * 1000

    stats_sync = layer_sync.manager.get_stats()
    print(f"Total time: {sync_time_ms:.2f} ms")
    print(f"Transfer time: {stats_sync['total_transfer_time_ms']:.2f} ms")
    print(f"Compute time: {sync_time_ms - stats_sync['total_transfer_time_ms']:.2f} ms")
    print(f"Load count: {stats_sync['load_count']}")

    # Asynchronous version
    print(f"\n--- Asynchronous transfer (non_blocking=True) ---")
    layer_async = ChunkedOffloadLinear(in_features, out_features, chunk_size, non_blocking=True)

    with torch.no_grad():
        start = time.time()
        _ = layer_async(x)
        torch.cuda.synchronize()
        async_time_ms = (time.time() - start) * 1000

    stats_async = layer_async.manager.get_stats()
    print(f"Total time: {async_time_ms:.2f} ms")
    print(f"Transfer time: {stats_async['total_transfer_time_ms']:.2f} ms")
    print(f"Compute time: {async_time_ms - stats_async['total_transfer_time_ms']:.2f} ms")
    print(f"Load count: {stats_async['load_count']}")

    # Comparison
    print(f"\n--- Comparison ---")
    print(f"Overall speedup: {sync_time_ms / async_time_ms:.2f}x")

    if stats_async['total_transfer_time_ms'] > 0:
        print(f"Transfer speedup: {stats_sync['total_transfer_time_ms'] / stats_async['total_transfer_time_ms']:.2f}x")

    print("\nKey findings:")
    print(f"  1. Synchronous transfers block the CPU thread")
    print(f"  2. Asynchronous transfers can improve throughput")
    print(f"  3. The first run includes JIT compilation overhead")
    print("PASSED\n")

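The sync-vs-async gap this test measures can be sketched with a back-of-the-envelope overlap model. The per-chunk timings below are illustrative assumptions, not measurements; the point is that with perfect overlap the slower of transfer and compute becomes the bottleneck instead of their sum:

```python
# Illustrative timings (ms) for an 8-chunk layer; numbers are assumptions.
transfer_per_chunk = 6.0
compute_per_chunk = 4.0
num_chunks = 8

# Synchronous: every chunk pays transfer + compute back to back.
sync_total = num_chunks * (transfer_per_chunk + compute_per_chunk)

# Perfectly overlapped: compute for chunk i hides behind the transfer of
# chunk i+1, so after the pipeline fills, the bottleneck term dominates.
async_total = (transfer_per_chunk
               + (num_chunks - 1) * max(transfer_per_chunk, compute_per_chunk)
               + compute_per_chunk)
```

Real runs will not reach the perfect-overlap bound (stream synchronization, pinned-memory allocation, and JIT warm-up all add overhead), which is why the test reports measured rather than modeled speedups.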
def test_10_profiler_analysis():
    """Analyze kernel execution with the PyTorch profiler."""
    print("\n=== Test 10: Profiler Analysis ===")

    if not torch.cuda.is_available():
        print("CUDA not available, skipping")
        return

    in_features = 4096
    out_features = 12244
    seqlen = 16384
    chunk_size = 4096

    layer = ChunkedOffloadLinear(in_features, out_features, chunk_size)
    x = torch.randn(seqlen, in_features, device="cuda", dtype=torch.float16)

    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p:
        with torch.no_grad():
            _ = layer(x)
        torch.cuda.synchronize()

    kernel_counts = {}
    for event in p.key_averages():
        if event.device_type == torch.profiler.DeviceType.CUDA:
            name = event.key
            kernel_counts[name] = kernel_counts.get(name, 0) + 1

    print(f"Kernel call statistics:")
    print(f"{'Kernel type':<50} {'Calls':<10}")
    print("-" * 60)

    for name, count in sorted(kernel_counts.items(), key=lambda x: -x[1])[:15]:
        name_short = name[:48]
        print(f"{name_short:<50} {count:<10}")

    memcpy_count = sum(count for name, count in kernel_counts.items() if 'memcpy' in name.lower())
    print(f"\nAnalysis:")
    print(f"  - {len(kernel_counts)} distinct CUDA kernel types")
    print(f"  - Total calls: {sum(kernel_counts.values())}")
    print(f"  - Memory copies: {memcpy_count}")
    print("PASSED\n")

# ============================================================
# Main test entry point
# ============================================================

def main():
    """Run all tests."""
    print("=" * 70)
    print("OffloadedTensor Unified Test Suite")
    print("=" * 70)

    # Functional tests
    print("\n" + "=" * 70)
    print("Functional tests (Tests 1-4)")
    print("=" * 70)
    test_1_basic_offloaded_tensor()
    test_2_mlp_with_offload()
    test_3_lru_eviction()
    test_4_correctness()

    # Performance tests
    print("\n" + "=" * 70)
    print("Performance tests (Tests 5-8)")
    print("=" * 70)
    test_5_memory_analysis()
    test_6_long_sequence()
    test_7_performance_comparison()
    test_8_transformers_layer()

    # Sync analysis
    print("\n" + "=" * 70)
    print("Sync analysis (Tests 9-10)")
    print("=" * 70)
    test_9_sync_behavior_analysis()
    test_10_profiler_analysis()

    print("=" * 70)
    print("All tests complete!")
    print("=" * 70)


if __name__ == "__main__":
    main()
@@ -1,244 +0,0 @@
"""
Test: Compare xattn_estimate vs xattn_estimate_chunked

Verify that chunked estimation with EXTERNAL chunking produces the same mask
as standard estimation. This ensures the chunked version can be used in
chunked prefill scenarios without accuracy loss.

Usage:
    CUDA_VISIBLE_DEVICES=0 PYTHONPATH=/home/zijie/Code/nano-vllm:$PYTHONPATH \
        python tests/test_xattn_estimate_chunked.py
"""

import sys
import traceback
import torch
from nanovllm.ops.xattn import xattn_estimate, xattn_estimate_chunked

# ============================================================
# Configuration
# ============================================================

# Configuration for the xattn_estimate_chunked consistency test.
# Key requirements for a 100% match:
#   1. Use a matching chunk_size for both the standard and chunked versions
#   2. Use the same random seed for reproducibility
# Note: Tiny differences (~1e-6) may occur at boundary cases due to
# floating-point precision in cumulative sum calculations.
BLOCK_SIZE = 64
STRIDE = 4
THRESHOLD = 0.9
CHUNK_SIZE = 4096  # External chunking size

# Test sequence lengths
TEST_SEQ_LENS = [4096, 8192, 16384, 32768]

# ============================================================
# Utility Functions
# ============================================================
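The boundary-case note above is ordinary floating-point behavior: combining per-chunk partial sums can land on a different double than a single left-to-right pass over the full sequence. A minimal pure-Python illustration of the same effect:

```python
# Summing the same values in a different grouping (full-sequence pass vs
# per-chunk partial sums) can differ in the last bits of the result.
values = [0.1] * 10

full_sum = sum(values)           # one left-to-right pass
chunk_a = sum(values[:5])        # partial sum of chunk 1
chunk_b = sum(values[5:])        # partial sum of chunk 2
chunked_sum = chunk_a + chunk_b  # combine partial sums
```

The two results differ by well under 1e-12, which is exactly the magnitude of mismatch the threshold comparison can amplify into a flipped mask bit at a boundary block.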
def compare_masks(mask1, mask2, name1="standard", name2="chunked"):
    """Compare two masks and report differences."""
    if mask1.shape != mask2.shape:
        print(f"  Shape mismatch: {name1}={mask1.shape}, {name2}={mask2.shape}")
        return False

    diff = (mask1 != mask2).sum().item()
    total = mask1.numel()
    match_rate = (total - diff) / total * 100

    print(f"  Match rate: {match_rate:.4f}% ({total - diff}/{total})")

    if diff > 0:
        diff_indices = torch.where(mask1 != mask2)
        print(f"  First 5 diff positions: {list(zip(*[idx[:5].tolist() for idx in diff_indices]))}")

    return diff == 0

def run_chunked_externally(query, key, block_size, stride, threshold, chunk_size):
    """
    Run xattn_estimate_chunked with EXTERNAL chunking.

    This simulates how chunked prefill should be used in practice.
    """
    batch_size, num_heads, q_len, head_dim = query.shape
    _, _, k_len, _ = key.shape

    q_block_num = (q_len + block_size - 1) // block_size
    k_block_num = (k_len + block_size - 1) // block_size

    # If Q fits in one chunk, call directly
    if q_len <= chunk_size:
        return xattn_estimate_chunked(
            query, key,
            q_start_pos=0,
            block_size=block_size,
            stride=stride,
            threshold=threshold,
            use_triton=True,
            chunk_size=chunk_size,
        )

    # External chunking: split Q and call once per chunk
    num_q_chunks = (q_len + chunk_size - 1) // chunk_size
    print(f"  External chunking: {num_q_chunks} chunks")

    combined_attn_sum = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=query.dtype, device=query.device
    )
    combined_mask = torch.zeros(
        batch_size, num_heads, q_block_num, k_block_num,
        dtype=torch.bool, device=query.device
    )

    q_block_offset = 0
    for q_chunk_idx in range(num_q_chunks):
        q_chunk_start = q_chunk_idx * chunk_size
        q_chunk_end = min((q_chunk_idx + 1) * chunk_size, q_len)

        q_chunk = query[:, :, q_chunk_start:q_chunk_end, :]

        # For causal attention, K accumulates up to the current Q position.
        # q_start_pos=0 means Q starts at position 0 in the full sequence;
        # K is [0, q_chunk_end) for causal attention.
        k_end = q_chunk_end
        k_chunk = key[:, :, :k_end, :]

        attn_sum_chunk, mask_chunk = xattn_estimate_chunked(
            q_chunk, k_chunk,
            q_start_pos=q_chunk_start,
            block_size=block_size,
            stride=stride,
            threshold=threshold,
            use_triton=True,
            chunk_size=chunk_size,
        )

        # Place chunk results into the combined output
        chunk_q_blocks = mask_chunk.shape[2]
        chunk_k_blocks = mask_chunk.shape[3]
        combined_attn_sum[:, :, q_block_offset:q_block_offset + chunk_q_blocks, :chunk_k_blocks] = attn_sum_chunk
        combined_mask[:, :, q_block_offset:q_block_offset + chunk_q_blocks, :chunk_k_blocks] = mask_chunk
        q_block_offset += chunk_q_blocks

    return combined_attn_sum, combined_mask

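The block-offset bookkeeping in the loop above can be checked with plain arithmetic. This standalone sketch mirrors the loop's index math for a 16K-token sequence matching the test configuration (no tensors involved, just the chunk/block accounting):

```python
q_len, chunk_size, block_size = 16384, 4096, 64

num_q_chunks = (q_len + chunk_size - 1) // chunk_size  # ceil-divide Q into chunks
q_block_num = (q_len + block_size - 1) // block_size   # total q-blocks in the output

offsets = []
q_block_offset = 0
for i in range(num_q_chunks):
    q_chunk_end = min((i + 1) * chunk_size, q_len)
    # q-blocks produced by this chunk (last chunk may be short)
    chunk_q_blocks = (min(chunk_size, q_len - i * chunk_size) + block_size - 1) // block_size
    # causal attention: K runs from 0 up to the end of this Q chunk
    k_blocks = (q_chunk_end + block_size - 1) // block_size
    offsets.append((q_block_offset, chunk_q_blocks, k_blocks))
    q_block_offset += chunk_q_blocks
```

Each chunk writes a growing causal band of K blocks at a strictly increasing Q-block offset, and the offsets tile the output exactly once.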
def test_single_seq_len(seq_len, num_heads=32, head_dim=128):
    """Test a single sequence length."""
    print(f"\nTesting seq_len={seq_len}")
    print("=" * 60)

    # Generate random Q/K
    query = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
    key = torch.randn(1, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)

    # Run standard xattn_estimate
    print("[1] Running standard xattn_estimate...")
    try:
        attn_sum_std, mask_std = xattn_estimate(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
            use_triton=True,
            causal=True,
        )
        density_std = mask_std.float().mean().item()
        print(f"  mask shape: {mask_std.shape}, density: {density_std:.4f}")
    except Exception as e:
        print(f"  ERROR: {e}")
        traceback.print_exc()
        return False

    # Run chunked xattn_estimate with EXTERNAL chunking
    print("[2] Running chunked xattn_estimate (external chunking)...")
    try:
        attn_sum_chunked, mask_chunked = run_chunked_externally(
            query, key,
            block_size=BLOCK_SIZE,
            stride=STRIDE,
            threshold=THRESHOLD,
            chunk_size=CHUNK_SIZE,
        )
        density_chunked = mask_chunked.float().mean().item()
        print(f"  mask shape: {mask_chunked.shape}, density: {density_chunked:.4f}")
    except Exception as e:
        print(f"  ERROR: {e}")
        traceback.print_exc()
        return False

    # Compare results
    print("[3] Comparing results...")
    chunked_q_blocks = mask_chunked.shape[2]
    chunked_k_blocks = mask_chunked.shape[3]

    # Extract the comparable region from the standard mask
    mask_std_comparable = mask_std[:, :, :chunked_q_blocks, :chunked_k_blocks]

    # Compare masks
    masks_match = compare_masks(mask_std_comparable, mask_chunked, "standard", "chunked")

    # Compare attn_sums
    attn_sum_std_comparable = attn_sum_std[:, :, :chunked_q_blocks, :chunked_k_blocks]
    if attn_sum_std_comparable.shape == attn_sum_chunked.shape:
        attn_diff = (attn_sum_std_comparable - attn_sum_chunked).abs().max().item()
        print(f"  Attn sum max diff: {attn_diff:.6f}")
    else:
        print(f"  Attn sum shape mismatch: std={attn_sum_std_comparable.shape}, chunked={attn_sum_chunked.shape}")

    # Clean up GPU memory
    del query, key, attn_sum_std, mask_std, attn_sum_chunked, mask_chunked
    torch.cuda.empty_cache()

    return masks_match

# ============================================================
# Main Test
# ============================================================

if __name__ == "__main__":
    print("XAttention Chunked vs Standard Test")
    print("=" * 60)
    print(f"Config: block_size={BLOCK_SIZE}, stride={STRIDE}, threshold={THRESHOLD}")
    print(f"External chunk_size={CHUNK_SIZE}")
    print()

    # Check CUDA availability
    if not torch.cuda.is_available():
        print("CUDA not available!")
        sys.exit(1)

    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
    print("✓ xattn_estimate imported")
    print("✓ xattn_estimate_chunked imported")

    # Run tests
    all_passed = True
    results = []

    for seq_len in TEST_SEQ_LENS:
        passed = test_single_seq_len(seq_len)
        chunks = (seq_len + CHUNK_SIZE - 1) // CHUNK_SIZE
        results.append((seq_len, chunks, passed))
        if not passed:
            all_passed = False

    # Summary
    print("\n" + "=" * 60)
    print("SUMMARY")
    print("=" * 60)
    for seq_len, chunks, passed in results:
        status = "PASSED" if passed else "FAILED"
        print(f"  seq_len={seq_len:5d} ({chunks} chunk{'s' if chunks > 1 else ' '}): {status}")

    print("=" * 60)
    if all_passed:
        print("ALL TESTS PASSED!")
        sys.exit(0)
    else:
        print("SOME TESTS FAILED!")
        sys.exit(1)