🔀 merge: integrate tzj/minference-exp (GPU-only sparse attention)
Merge GPU-only sparse attention support from tzj/minference-exp branch:

**GPU-only mode additions:**
- Add compute_prefill/compute_decode methods to SparsePolicy base class
- Add GPU-only attention routing in attention.py
- Add alloc_policy_metadata() for pre-allocating GQA buffers
- Add XAttention + BSA sparse attention for GPU-only prefill
- Add kvcache_manager to set_context() for policy access

**bench.py enhancements:**
- Add --model argument for configurable model path
- Add --policy argument (full, xattn) for sparse policy selection
- Add --enable-policy flag for FullAttentionPolicy routing
- Add --enforce-eager option to disable CUDA graphs
- Add --gpu-util option for GPU memory utilization

**Documentation:**
- Add gpu_only_xattn_guide.md with performance analysis
- Add gpu_only_sparse_integration.md baseline document
- Add gpu-vram-requirement.md rule for GPU-only mode

Both CPU offload and GPU-only paths are preserved and functional.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
54
.claude/rules/gpu-vram-requirement.md
Normal file
@@ -0,0 +1,54 @@
# GPU VRAM Requirement Rule

## VRAM requirement for GPU-only mode

**Mandatory rule**: GPU-only code (CPU offload disabled) **must** be tested on a GPU with at least 40GB of VRAM.

### How to check

Before running any GPU-only test, you **must** check the GPU's VRAM first:

```bash
nvidia-smi --query-gpu=index,name,memory.total --format=csv,noheader
```

### GPU classification

| GPU model | VRAM | GPU-only testing |
|-----------|------|------------------|
| A100 40GB | 40GB | ✅ Allowed |
| A100 80GB | 80GB | ✅ Allowed |
| H100 80GB | 80GB | ✅ Allowed |
| A6000 | 48GB | ✅ Allowed |
| RTX 3090 | 24GB | ❌ **Forbidden** (offload mode only) |
| RTX 4090 | 24GB | ❌ **Forbidden** (offload mode only) |

### Procedure

1. **Check GPU VRAM** (required)
2. **VRAM >= 40GB**: proceed with the GPU-only test
3. **VRAM < 40GB**: **stop** and tell the user:
   > "The current GPU has XXX GB of VRAM, which is below the 40GB minimum for GPU-only mode. Use the `--enable-offload` flag to enable CPU offload mode."

### Code example

```python
# Before running a GPU-only benchmark
import subprocess

result = subprocess.run(
    ["nvidia-smi", "--query-gpu=memory.total", "--format=csv,noheader,nounits"],
    capture_output=True, text=True
)
# nvidia-smi reports memory.total in MiB; only the first GPU is checked here
vram_mb = int(result.stdout.strip().split('\n')[0])
if vram_mb < 40000:  # 40GB threshold
    raise RuntimeError(f"GPU VRAM ({vram_mb}MB) < 40GB. Use --enable-offload for this GPU.")
```
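
The check in this rule reads only the first GPU that `nvidia-smi` reports. As a hedged sketch, the same 40GB threshold can be applied to every listed GPU. The helper names `parse_vram_mib` and `gpu_only_capable` are hypothetical (they are not part of the repository), and the sample output string is made up for illustration.

```python
MIN_VRAM_MIB = 40000  # same threshold as the rule (nvidia-smi reports MiB)


def parse_vram_mib(nvidia_smi_output: str) -> list[int]:
    """Parse the output of
    `nvidia-smi --query-gpu=memory.total --format=csv,noheader,nounits`
    into one integer (MiB) per GPU."""
    return [int(line) for line in nvidia_smi_output.splitlines() if line.strip()]


def gpu_only_capable(nvidia_smi_output: str, min_mib: int = MIN_VRAM_MIB) -> list[int]:
    """Return the indices of the GPUs that meet the GPU-only VRAM requirement."""
    return [
        index
        for index, mib in enumerate(parse_vram_mib(nvidia_smi_output))
        if mib >= min_mib
    ]


# Illustrative output for a machine with one large and one small GPU (made-up values).
sample = "81920\n24576\n"
print(gpu_only_capable(sample))
```

A caller could pass the returned indices to `CUDA_VISIBLE_DEVICES` and fall back to `--enable-offload` when the list is empty.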

### Scope

| Script | Rule applies? |
|--------|---------------|
| `bench.py` | ✅ Must check VRAM |
| `bench_offload.py` | ❌ Not applicable (always uses offload) |
| `tests/test_*.py --enable-offload` | ❌ Not applicable |
| `tests/test_*.py` (no offload) | ✅ Must check VRAM |
@@ -1,5 +1,39 @@
 # Sparse Policy Code Conventions
 
+## Policy must never be None (CRITICAL)
+
+**Mandatory rule**: the `sparse_policy` parameter **must never be None**; it must be at least `FullAttentionPolicy`.
+
+```python
+# ❌ Wrong: allows None
+sparse_policy = getattr(config, 'sparse_policy', None)
+
+# ✅ Correct: handle None explicitly and default to FULL
+sparse_policy_type = getattr(config, 'sparse_policy', None)
+if sparse_policy_type is None:
+    sparse_policy_type = SparsePolicyType.FULL
+```
+
+**Rationale**:
+1. Unified API: every code path computes attention through the policy
+2. No null dereferences: removes the None checks around `policy.xxx` calls
+3. Simpler logic: no `if policy is not None` branches
+
+**The only exception: the warmup phase**
+
+During `model_runner.warmup_model()`, the kvcache_manager has not been allocated yet. In that case `attention.py` falls back to flash_attn:
+
+```python
+# Warmup handling in attention.py
+if context.kvcache_manager is None:
+    # Warmup phase: use flash_attn directly
+    return flash_attn_varlen_func(...) if context.is_prefill else flash_attn_with_kvcache(...)
+```
+
+This is the only case in which kvcache_manager may be None. During actual inference, the policy must exist.
+
+---
+
 ## Base class requirements (MANDATORY)
 
 Every `SparsePolicy` subclass **must** meet the following requirements:
@@ -29,6 +29,7 @@ Nano-vLLM is a lightweight vLLM implementation (~1,200 lines) for fast offline L
 | [`docs/cpu_scheduling_latency_analysis.md`](docs/cpu_scheduling_latency_analysis.md) | ⚡ PERF: CPU scheduling latency analysis, sources of inter-kernel gaps, directions for improving GPU utilization |
 | [`docs/bench_offload_results.md`](docs/bench_offload_results.md) | 📊 BENCH: CPU offload benchmark results, Full vs XAttention comparison (32K/128K) |
 | [`docs/cpu_offload_optimization_strategies.md`](docs/cpu_offload_optimization_strategies.md) | 🚀 OPT: CPU offload optimization strategies: chunk size, CUDA Graph, recent research (InfiniGen/ShadowKV) |
+| [`docs/gpu_only_xattn_guide.md`](docs/gpu_only_xattn_guide.md) | 🚀 GPU-Only XAttention: memory pre-allocation, performance analysis (32K +15%, 64K +41%), CUDA Graph limitations |
 
 ## Rules Index
 
29
bench.py
@@ -40,24 +40,49 @@ def bench_prefill(llm, num_seqs, input_len):
 
 def main():
     import argparse
+    from nanovllm.config import SparsePolicyType
+
     parser = argparse.ArgumentParser(description="Benchmark nanovllm GPU performance")
+    parser.add_argument("--model", type=str, default="~/models/Llama-3.1-8B-Instruct",
+                        help="Model path (default: ~/models/Llama-3.1-8B-Instruct)")
     parser.add_argument("--input-len", type=int, default=None, help="Input length in tokens")
     parser.add_argument("--output-len", type=int, default=64, help="Output length for decode benchmark (default: 64)")
     parser.add_argument("--max-len", type=int, default=32*1024, help="Max model length (default: 32K)")
     parser.add_argument("--bench-decode", action="store_true", help="Run decode benchmark (default: prefill only)")
     parser.add_argument("--bench-all", action="store_true", help="Run both prefill and decode benchmarks")
+    # Sparse policy option (GPU-only mode now supports policy routing)
+    parser.add_argument("--policy", type=str, default=None,
+                        choices=["full", "xattn"],
+                        help="Sparse policy: full (FullAttention), xattn (XAttention+BSA)")
+    parser.add_argument("--enable-policy", action="store_true",
+                        help="Enable sparse policy routing (FullAttentionPolicy by default)")
+    parser.add_argument("--gpu-util", type=float, default=0.9,
+                        help="GPU memory utilization (default: 0.9)")
+    parser.add_argument("--enforce-eager", action="store_true",
+                        help="Disable CUDA graphs (default: False)")
     args = parser.parse_args()
 
-    path = os.path.expanduser("~/models/Qwen3-4B-Instruct-2507/")
+    path = os.path.expanduser(args.model)
     max_len = args.max_len
 
+    # Configure sparse policy
+    if args.policy == "xattn":
+        sparse_policy = SparsePolicyType.XATTN_BSA
+        print(f"\n[nanovllm GPU + XAttention BSA] max_len={max_len}")
+    elif args.policy == "full" or args.enable_policy:
+        sparse_policy = SparsePolicyType.FULL
+        print(f"\n[nanovllm GPU + Policy] sparse_policy=FULL, max_len={max_len}")
+    else:
+        sparse_policy = None
-    print(f"\n[nanovllm GPU] max_len={max_len}")
+        print(f"\n[nanovllm GPU] max_len={max_len}")
 
     llm = LLM(
         path,
-        enforce_eager=False,
+        enforce_eager=args.enforce_eager,
         max_model_len=max_len,
         max_num_batched_tokens=max_len,
+        sparse_policy=sparse_policy,
+        gpu_memory_utilization=args.gpu_util,
     )
 
     # Warmup
77
docs/gpu_only_sparse_integration.md
Normal file
@@ -0,0 +1,77 @@
# GPU-only Sparse Policy Integration

This document records the process of integrating sparse attention policies into GPU-only mode, along with performance comparisons.

## Background

Sparse policies (Quest, XAttention) are currently implemented only in the CPU offload path. The goal is to extend them to GPU-only mode to improve performance in long-context scenarios.

## Baseline performance (before optimization)

**Test environment**:
- GPU: NVIDIA A100-SXM4-80GB
- Model: Llama-3.1-8B-Instruct
- Context length: 32K tokens
- Date: 2026-01-27

### Prefill Benchmark (32K context)

| Mode | Throughput | Time | KV cache allocation |
|------|------------|------|---------------------|
| **GPU-only (Full Attention)** | 4869.67 tok/s | 6.73s | 438 blocks (56GB GPU) |
| CPU Offload (Full Attention) | 1500.29 tok/s | 21.84s | 4 blocks GPU + 32 blocks CPU |

**Speedup**: GPU-only is **3.2x** faster than CPU Offload

### Configuration details

**GPU-only mode**:
```bash
CUDA_VISIBLE_DEVICES=0 python bench.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --max-len 32768
```

**CPU Offload mode**:
```bash
CUDA_VISIBLE_DEVICES=0 python bench_offload.py \
    --model ~/models/Llama-3.1-8B-Instruct \
    --max-len 32768
```

### KV cache configuration

| Parameter | GPU-only | CPU Offload |
|-----------|----------|-------------|
| block_size | 1024 tokens | 1024 tokens |
| per-token KV | 128 KB | 128 KB |
| per-block KV | 128 MB | 128 MB |
| GPU blocks | 438 | 4 |
| CPU blocks | 0 | 32 |
| Total memory | 56 GB | 4.6 GB |
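
The figures in the KV cache table can be reproduced with a short calculation. This is a sketch under stated assumptions: it takes Llama-3.1-8B to have 32 layers, 8 KV heads, and a head dimension of 128, and it assumes a 2-byte (16-bit) KV dtype. None of these values are stated in this document, and the function name is hypothetical.

```python
def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       dtype_bytes: int = 2) -> int:
    """Bytes of K plus V stored for one token across all layers."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


BLOCK_SIZE = 1024  # tokens per block, from the table

# Assumed Llama-3.1-8B shape: 32 layers, 8 KV heads, head_dim 128.
per_token = kv_bytes_per_token(num_layers=32, num_kv_heads=8, head_dim=128)
per_block = per_token * BLOCK_SIZE

gpu_only_mib = 438 * per_block // 2**20      # 438 GPU blocks
offload_mib = (4 + 32) * per_block // 2**20  # 4 GPU blocks + 32 CPU blocks

print(f"per token: {per_token // 1024} KiB")
print(f"per block: {per_block // 2**20} MiB")
print(f"GPU-only total: {gpu_only_mib} MiB")
print(f"CPU offload total: {offload_mib} MiB")
```

Under these assumptions the totals come out to 56064 MiB and 4608 MiB, which line up with the table's 56 GB and 4.6 GB if the MiB counts are divided by 1000.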

## Goals

Integrate the following sparse policies into GPU-only mode:

| Policy | Phase | Description |
|--------|-------|-------------|
| Quest | Decode | Top-K block selection based on query-key scores |
| XAttention BSA | Prefill | Block sparse attention with cumulative threshold |

## Implementation progress

- [ ] Analyze the existing sparse policy code structure
- [ ] Design the GPU-only sparse policy interface
- [ ] Implement GPU-only Quest decode
- [ ] Implement GPU-only XAttention prefill
- [ ] Performance testing and comparison

## Performance after optimization

*To be tested*

| Mode | Throughput | Speedup vs Full |
|------|------------|-----------------|
| GPU-only + Quest (decode) | TBD | TBD |
| GPU-only + XAttn (prefill) | TBD | TBD |
296
docs/gpu_only_xattn_guide.md
Normal file
@@ -0,0 +1,296 @@
# GPU-Only XAttention Guide

This document covers the implementation, memory optimization, and performance analysis of XAttention BSA in GPU-only mode.

## Overview

In GPU-only mode, the entire KV cache lives on the GPU and no CPU offload is needed. XAttention speeds up the prefill phase through sparse attention.

### Execution path comparison

| Mode | Prefill method | Decode method | KV storage |
|------|----------------|---------------|------------|
| GPU-only Full | `compute_prefill()` | `compute_decode()` | GPU |
| GPU-only XAttn | `compute_prefill()` | `compute_decode()` | GPU |
| CPU Offload | `compute_chunked_prefill()` | `compute_chunked_decode()` | CPU + GPU |

## Architecture

### SparsePolicy interface

```python
class SparsePolicy:
    # GPU-only methods
    def compute_prefill(self, q, k, v, ...) -> Tensor
    def compute_decode(self, q, k_cache, v_cache, ...) -> Tensor

    # CPU offload methods
    def compute_chunked_prefill(self, q, k, v, ...) -> Tensor
    def compute_chunked_decode(self, q, ...) -> Tensor

    # Initialization methods
    def initialize(self, num_layers, ...) -> None  # CPU offload metadata
    def alloc_policy_metadata(self, num_heads, ...) -> None  # GPU-only buffers
```

### XAttentionBSAPolicy implementation

```
GPU-only prefill flow:
┌─────────────────────────────────────────────────────────────┐
│ 1. GQA expansion (uses pre-allocated buffer)                │
│    K: [seq, kv_heads, dim] → K_exp: [1, heads, seq, dim]    │
│                                                             │
│ 2. XAttention estimation                                    │
│    flat_group_gemm_fuse_reshape_kernel (Q@K^T)              │
│    softmax_fuse_block_sum_kernel (block importance)         │
│    → sparse mask                                            │
│                                                             │
│ 3. BSA sparse attention                                     │
│    flash_fwd_block_kernel (selected blocks only)            │
│    → output                                                 │
└─────────────────────────────────────────────────────────────┘
```

## Memory pre-allocation

### Background

XAttention's `compute_prefill()` requires GQA expansion:

```python
# Before: dynamic allocation (~2GB for 64K)
K_exp = K.repeat_interleave(num_groups, dim=1)  # allocation 1
k_bsa = k.repeat_interleave(num_groups, dim=1)  # allocation 2 (duplicate!)
```

Allocating dynamically on every prefill leads to:
- Memory fragmentation
- Allocation latency
- Possible OOM

### Solution: alloc_policy_metadata()

Pre-allocate the buffers when the framework initializes:

```python
class XAttentionBSAPolicy(SparsePolicy):
    def alloc_policy_metadata(self, num_heads, num_kv_heads, head_dim,
                              max_seq_len, dtype, device):
        # Pre-allocate the GQA expansion buffers
        shape = (1, num_heads, max_seq_len, head_dim)
        self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
        self._v_expanded = torch.empty(shape, dtype=dtype, device=device)

    def compute_prefill(self, q, k, v, ...):
        seq_len = k.shape[0]
        # Use a slice of the pre-allocated buffer
        K_exp = self._k_expanded[:, :, :seq_len, :]
        # In-place GQA expansion
        K_exp.view(...).copy_(K.unsqueeze(2).expand(...))
        # Reuse the same buffer for BSA
        k_bsa = K_exp.squeeze(0).transpose(0, 1)
```

### Memory usage

| Sequence length | Pre-allocated size | Notes |
|-----------------|--------------------|-------|
| 32K | 512 MB | `2 * 32 * 32768 * 128 * 2 bytes` |
| 64K | 1024 MB | `2 * 32 * 65536 * 128 * 2 bytes` |

Effect of the optimization:
- Before: ~2GB dynamically allocated (once each for xattn_estimate and BSA)
- After: ~1GB pre-allocated (the same buffer is reused)
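
The sizes in the memory usage table follow directly from the formulas in its last column. A minimal sketch of that arithmetic, assuming the factors stand for 2 buffers (K and V), 32 attention heads, a head dimension of 128, and 2 bytes per element; the function name `gqa_buffer_bytes` is hypothetical:

```python
def gqa_buffer_bytes(num_heads: int, max_seq_len: int, head_dim: int,
                     dtype_bytes: int = 2, num_buffers: int = 2) -> int:
    """Total bytes for the pre-allocated expanded K and V buffers,
    each of shape (1, num_heads, max_seq_len, head_dim)."""
    return num_buffers * num_heads * max_seq_len * head_dim * dtype_bytes


for seq_len in (32 * 1024, 64 * 1024):
    mib = gqa_buffer_bytes(num_heads=32, max_seq_len=seq_len, head_dim=128) // 2**20
    print(f"{seq_len // 1024}K -> {mib} MiB")
```

Because the buffers are sized by `max_seq_len`, doubling the maximum context doubles the reservation regardless of how long the actual prompts are.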

### Framework integration

```python
# model_runner.py - allocate_kv_cache()
def allocate_kv_cache(self):
    # ... KV cache allocation ...

    # GPU-only mode: pre-allocate policy buffers
    if not config.enable_cpu_offload:
        self.kvcache_manager.sparse_policy.alloc_policy_metadata(
            num_heads=num_heads,
            num_kv_heads=num_kv_heads,
            head_dim=head_dim,
            max_seq_len=config.max_model_len,
            dtype=dtype,
            device=torch.device("cuda"),
        )
```

## Performance analysis

### 32K prefill performance

| Policy | Throughput | Relative gain |
|--------|------------|---------------|
| Baseline | 4880 tok/s | - |
| Full | 4892 tok/s | +0.2% |
| **XAttention** | **5602 tok/s** | **+15%** |

### 64K prefill performance

| Policy | Throughput | Relative gain |
|--------|------------|---------------|
| Baseline | 3386 tok/s | - |
| Full | 3355 tok/s | -0.9% |
| **XAttention** | **4775 tok/s** | **+41%** |

### Kernel time breakdown (32K)

**XAttention:**
```
FFN GEMM:          3219 ms  (55%)
BSA Attention:     1231 ms  (21%)
XAttn Estimation:   415 ms  (7%)
Other:             1020 ms  (17%)
─────────────────────────────
Total:             5885 ms
```

**Full:**
```
FFN GEMM:          3244 ms  (48%)
Dense Attention:   2861 ms  (43%)
Other:              595 ms  (9%)
─────────────────────────────
Total:             6700 ms
```

### Where the speedup comes from

```
Dense Attention:   2861 ms
BSA Attention:     1231 ms  (saves 1630 ms, -57%)
XAttn Estimation:   415 ms  (extra overhead)
─────────────────────────────
Net savings:       1215 ms  (42% of attention time)
```
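
The net-savings figure and the relative gains in the throughput tables can be checked with a few lines of arithmetic. This sketch only re-derives numbers already listed in this guide; the function name `relative_gain_percent` is hypothetical.

```python
# Attention kernel times at 32K, taken from the breakdown in this guide (ms).
dense_ms, bsa_ms, estimate_ms = 2861, 1231, 415

saved_ms = dense_ms - bsa_ms      # time saved by sparse attention itself
net_ms = saved_ms - estimate_ms   # minus the cost of estimating the sparse mask
net_share = 100 * net_ms / dense_ms


def relative_gain_percent(new_tok_s: float, baseline_tok_s: float) -> float:
    """Throughput gain of `new_tok_s` over `baseline_tok_s`, in percent."""
    return (new_tok_s / baseline_tok_s - 1) * 100


print(f"saved {saved_ms} ms, net {net_ms} ms ({net_share:.0f}% of dense attention)")
print(f"32K gain: {relative_gain_percent(5602, 4880):.0f}%")
print(f"64K gain: {relative_gain_percent(4775, 3386):.0f}%")
```

The estimation pass costs roughly a quarter of what it saves at 32K, which is why the end-to-end gain is smaller than the 57% reduction in attention kernel time.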

## CUDA Graph limitations

### Why prefill cannot use CUDA Graph

CUDA Graph requires every operation to be fixed at capture time:

| Must be fixed | In prefill |
|---------------|------------|
| Tensor shapes | seq_len varies (1 ~ max_model_len) |
| Kernel grid | Depends on seq_len |
| Memory addresses | Intermediate tensor sizes change |

```python
# Different requests have different seq_len
request_1: prefill(seq_len=1024)   # grid=(8, 32, 1)
request_2: prefill(seq_len=32768)  # grid=(256, 32, 1)
```
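
The grid sizes quoted in the prefill example are consistent with a launch that tiles the sequence in chunks of 128 tokens and uses one grid row per attention head. That tiling is an inference from the numbers (1024 / 128 = 8, 32768 / 128 = 256), not something this guide states, so the sketch below should be read as an illustration; `prefill_grid` and its default values are hypothetical.

```python
import math


def prefill_grid(seq_len: int, num_heads: int = 32, tile: int = 128) -> tuple[int, int, int]:
    """Launch grid for a kernel that assigns one thread block per
    (sequence tile, head) pair. The tile size of 128 is an assumption."""
    return (math.ceil(seq_len / tile), num_heads, 1)


for seq_len in (1024, 32768):
    print(seq_len, prefill_grid(seq_len))
```

Since the first grid dimension changes with every distinct `seq_len`, a graph captured for one prompt length cannot be replayed for another.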

### Decode can use CUDA Graph

```python
# Decode processes only 1 token per step
q: [batch_size, 1, heads, dim]  # shape is fixed
```

nanovllm captures one graph per batch_size ahead of time:

```python
def capture_cudagraph(self):
    for batch_size in [1, 2, 4, 8, ...]:
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            self.run_model(dummy_input, is_prefill=False)
        self.graphs[batch_size] = g
```

### Nsys profile results

```
XAttention 32K Prefill:
  Total kernels: 41,904
  Non-graph:     41,904 (100%)
  Graph:         0

Full 32K Prefill:
  Total kernels: 35,308
  Non-graph:     35,308 (100%)
  Graph:         0
```

**Both are 100% NON-GRAPH**, which is inherent to prefill.

## Profiling tools

### Using profile.sh

```bash
# XAttention 32K
bash scripts/profile.sh --max-len 32768 --policy xattn

# Full 32K
bash scripts/profile.sh --max-len 32768 --policy full

# 64K (requires a lower gpu-util)
bash scripts/profile.sh --max-len 65536 --policy xattn --gpu-util 0.7
```

### Analyzing nsys results

```bash
# Show kernel statistics
nsys stats --report cuda_gpu_kern_sum results/nsys/<file>.nsys-rep

# Query detailed data with sqlite
sqlite3 results/nsys/<file>.sqlite "
SELECT
    (SELECT value FROM StringIds WHERE id = shortName) as kernel,
    COUNT(*) as count,
    SUM(end-start)/1e6 as total_ms
FROM CUPTI_ACTIVITY_KIND_KERNEL
GROUP BY shortName
ORDER BY total_ms DESC
LIMIT 10
"
```

## Usage guide

### Enabling XAttention in GPU-only mode

```python
from nanovllm import LLM
from nanovllm.config import SparsePolicyType

llm = LLM(
    model_path,
    max_model_len=32768,
    sparse_policy=SparsePolicyType.XATTN_BSA,
    gpu_memory_utilization=0.9,  # may need to be lowered for 64K
)
```

### Command-line testing

```bash
# bench.py
python bench.py --max-len 32768 --policy xattn

# 64K requires a lower gpu-util
python bench.py --max-len 65536 --policy xattn --gpu-util 0.7
```

### Best practices

1. **32K and below**: use the default `gpu_memory_utilization=0.9`
2. **64K**: lower it to `gpu_memory_utilization=0.7`
3. **Decode**: XAttention automatically falls back to FullAttentionPolicy
4. **Paged KV Cache**: automatically falls back to flash_attn when `block_tables` is present

## Related documents

- [Sparse Policy architecture](sparse_policy_architecture.md)
- [XAttention algorithm guide](xattention_algorithm_guide.md)
- [BSA interface documentation](block_sparse_attn_interface.md)
@@ -202,19 +202,36 @@ class ModelRunner:
             dtype=hf_config.torch_dtype,
         )
 
-        # Initialize sparse policy if manager has one (CPU offload mode)
+        # Initialize sparse policy if manager has one (works for both CPU offload and GPU-only modes)
         if hasattr(self.kvcache_manager, 'sparse_policy') and self.kvcache_manager.sparse_policy is not None:
+            # Use CPU blocks for offload mode, GPU blocks for GPU-only mode
+            num_blocks_for_init = config.num_cpu_kvcache_blocks if config.enable_cpu_offload else config.num_kvcache_blocks
             self.kvcache_manager.sparse_policy.initialize(
                 num_layers=hf_config.num_hidden_layers,
                 num_kv_heads=num_kv_heads,
                 head_dim=head_dim,
-                num_cpu_blocks=config.num_cpu_kvcache_blocks,
+                num_cpu_blocks=num_blocks_for_init,
                 dtype=hf_config.torch_dtype,
                 device=torch.device("cuda"),
             )
 
+            # GPU-only mode: pre-allocate policy metadata buffers
+            # This avoids dynamic GPU memory allocation during forward pass
+            if not config.enable_cpu_offload:
+                num_heads = hf_config.num_attention_heads // self.world_size
+                self.kvcache_manager.sparse_policy.alloc_policy_metadata(
+                    num_heads=num_heads,
+                    num_kv_heads=num_kv_heads,
+                    head_dim=head_dim,
+                    max_seq_len=config.max_model_len,
+                    dtype=hf_config.torch_dtype,
+                    device=torch.device("cuda"),
+                )
+
+            # Log policy info (handle both enum and None cases)
+            policy_name = config.sparse_policy.name if config.sparse_policy is not None else "FULL"
             logger.info(
-                f"Sparse policy initialized: {config.sparse_policy.name} "
+                f"Sparse policy initialized: {policy_name} "
                 f"(topk={config.sparse_topk_blocks}, threshold={config.sparse_threshold_blocks})"
             )
 
@@ -375,7 +392,16 @@ class ModelRunner:
         cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
-        set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables)
+        set_context(
+            is_prefill=True,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
+            slot_mapping=slot_mapping,
+            block_tables=block_tables,
+            kvcache_manager=getattr(self, 'kvcache_manager', None),
+        )
         return input_ids, positions
 
     def prepare_decode(self, seqs: list[Sequence]):
@@ -404,7 +430,13 @@ class ModelRunner:
         context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True)
         # Use GPU physical block tables for attention
         block_tables = self._prepare_gpu_block_tables(gpu_block_tables)
-        set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables)
+        set_context(
+            is_prefill=False,
+            slot_mapping=slot_mapping,
+            context_lens=context_lens,
+            block_tables=block_tables,
+            kvcache_manager=self.kvcache_manager,
+        )
         return input_ids, positions
 
     def _prepare_gpu_block_tables(self, gpu_block_tables: list[list[int]]):
@@ -713,7 +745,13 @@ class ModelRunner:
 
         for bs in reversed(self.graph_bs):
             graph = torch.cuda.CUDAGraph()
-            set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs])
+            set_context(
+                is_prefill=False,
+                slot_mapping=slot_mapping[:bs],
+                context_lens=context_lens[:bs],
+                block_tables=block_tables[:bs],
+                kvcache_manager=self.kvcache_manager,
+            )
             outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # warmup
             with torch.cuda.graph(graph, self.graph_pool):
                 outputs[:bs] = self.model(input_ids[:bs], positions[:bs])    # capture
@@ -25,7 +25,7 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
     Factory function to create the appropriate KV cache manager.
 
     Decision logic:
-    1. If enable_cpu_offload=False: use GPUOnlyManager
+    1. If enable_cpu_offload=False: use GPUOnlyManager (optionally with sparse policy)
     2. If enable_cpu_offload=True but all blocks fit in GPU: use GPUOnlyManager
     3. If enable_cpu_offload=True and need CPU blocks: use HybridKVCacheManager
 
@@ -37,9 +37,44 @@ def create_kvcache_manager(config: "Config") -> KVCacheManager:
     """
     if not getattr(config, 'enable_cpu_offload', False):
         # Default: pure GPU mode
+        # Check if sparse policy is requested for GPU-only mode
+        from nanovllm.config import SparsePolicyType
+        sparse_policy_type = getattr(config, 'sparse_policy', None)
+        # Handle None case - use FULL as default
+        if sparse_policy_type is None:
+            sparse_policy_type = SparsePolicyType.FULL
+
+        sparse_policy = None
+        if sparse_policy_type != SparsePolicyType.FULL:
+            # Create sparse policy for GPU-only mode
+            from nanovllm.kvcache.sparse import create_sparse_policy
+
+            policy_kwargs = {}
+            if sparse_policy_type == SparsePolicyType.QUEST:
+                policy_kwargs = {
+                    'topk_blocks': getattr(config, 'sparse_topk_blocks', 8),
+                    'threshold_blocks': getattr(config, 'sparse_threshold_blocks', 4),
+                }
+            elif sparse_policy_type == SparsePolicyType.XATTN_BSA:
+                policy_kwargs = {
+                    'block_size': getattr(config, 'sparse_block_size', 128),
+                    'samples_per_chunk': getattr(config, 'sparse_samples_per_chunk', 128),
+                    'threshold': getattr(config, 'sparse_threshold', 0.9),
+                    'use_triton': getattr(config, 'sparse_use_triton', True),
+                    'stride': getattr(config, 'sparse_stride', 8),
+                    'chunk_size': getattr(config, 'sparse_chunk_size', 16384),
+                }
+
+            sparse_policy = create_sparse_policy(sparse_policy_type, **policy_kwargs)
+        else:
+            # FULL policy for GPU-only mode - always create for consistent API
+            from nanovllm.kvcache.sparse import FullAttentionPolicy
+            sparse_policy = FullAttentionPolicy()
+
         return GPUOnlyManager(
             num_blocks=config.num_kvcache_blocks,
             block_size=config.kvcache_block_size,
+            sparse_policy=sparse_policy,
         )
 
     # CPU offload is enabled
@@ -7,13 +7,16 @@ the KVCacheManager interface.
 """
 
 from collections import deque
-from typing import List, Tuple, Dict, Optional
+from typing import List, Tuple, Dict, Optional, TYPE_CHECKING
 import torch
 from torch import Tensor
 
 from nanovllm.engine.sequence import Sequence
 from nanovllm.kvcache.base_manager import KVCacheManager
 
+if TYPE_CHECKING:
+    from nanovllm.kvcache.sparse.policy import SparsePolicy
+
 
 class Block:
     """Physical block in GPU memory."""
@@ -50,17 +53,28 @@ class GPUOnlyManager(KVCacheManager):
|
|||||||
all data stays on GPU at fixed addresses.
|
all data stays on GPU at fixed addresses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_blocks: int, block_size: int):
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_blocks: int,
|
||||||
|
block_size: int,
|
||||||
|
sparse_policy: Optional["SparsePolicy"] = None,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Initialize GPU-only manager.
|
Initialize GPU-only manager.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
num_blocks: Total number of blocks to manage
|
num_blocks: Total number of blocks to manage
|
||||||
block_size: Tokens per block (default 256)
|
block_size: Tokens per block (default 256)
|
||||||
|
sparse_policy: Optional sparse attention policy for GPU-only mode
|
||||||
"""
|
"""
|
||||||
self._block_size = block_size
|
self._block_size = block_size
|
||||||
self._num_blocks = num_blocks
|
self._num_blocks = num_blocks
|
||||||
|
|
||||||
|
# Sparse policy for GPU-only mode (optional)
|
||||||
|
self.sparse_policy = sparse_policy
|
||||||
|
# No offload engine in GPU-only mode
|
||||||
|
self.offload_engine = None
|
||||||
|
|
||||||
# Block metadata
|
# Block metadata
|
||||||
self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]
|
self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]
|
||||||
|
|
||||||
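The wiring introduced by these two hunks can be sketched in a few lines. This is only an illustration under stated assumptions: `StubFullPolicy`, `StubXAttnPolicy`, `StubGPUOnlyManager` and `make_gpu_only_manager` are hypothetical stand-ins, not the real nanovllm classes. The sketch shows the factory falling back to a full-attention policy, so that callers can always read `manager.sparse_policy` without a `None` check.

```python
from typing import Optional


class StubFullPolicy:
    """Hypothetical stand-in for FullAttentionPolicy (no sparsity)."""
    name = "full"


class StubXAttnPolicy:
    """Hypothetical stand-in for a sparse policy such as XAttention + BSA."""
    name = "xattn"


class StubGPUOnlyManager:
    """Hypothetical stand-in mirroring the new GPUOnlyManager constructor."""

    def __init__(self, num_blocks: int, block_size: int,
                 sparse_policy: Optional[object] = None):
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.sparse_policy = sparse_policy
        self.offload_engine = None  # GPU-only mode has no offload engine


def make_gpu_only_manager(num_blocks: int, block_size: int, policy=None):
    # Mirror the factory's behaviour: fall back to the full policy so the
    # attention layer never sees sparse_policy=None.
    if policy is None:
        policy = StubFullPolicy()
    return StubGPUOnlyManager(num_blocks, block_size, sparse_policy=policy)


default_mgr = make_gpu_only_manager(num_blocks=8, block_size=256)
sparse_mgr = make_gpu_only_manager(8, 256, policy=StubXAttnPolicy())
print(default_mgr.sparse_policy.name, sparse_mgr.sparse_policy.name)
```

Keeping the default inside the factory rather than in the manager means the manager constructor can still accept `None`, which matches the `Optional["SparsePolicy"]` annotation in the diff.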
|
|||||||
@@ -76,6 +76,75 @@ class FullAttentionPolicy(SparsePolicy):
|
|||||||
logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, "
|
logger.info(f"[Full Policy] Density Stats: chunks={stats['num_chunks']}, "
|
||||||
f"blocks={stats['total_available_blocks']}, density=100.0%")
|
f"blocks={stats['total_available_blocks']}, density=100.0%")
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# GPU-only methods (non-chunked)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
cu_seqlens_q: torch.Tensor,
|
||||||
|
cu_seqlens_k: torch.Tensor,
|
||||||
|
max_seqlen_q: int,
|
||||||
|
max_seqlen_k: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
layer_id: int,
|
||||||
|
block_tables: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
GPU-only prefill attention using flash_attn_varlen_func.
|
||||||
|
|
||||||
|
This is the simplest implementation - just call flash attention directly.
|
||||||
|
For sparse policies, this method would implement block selection.
|
||||||
|
"""
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
block_table=block_tables,
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
cache_seqlens: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
layer_id: int,
|
||||||
|
block_tables: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
GPU-only decode attention using flash_attn_with_kvcache.
|
||||||
|
|
||||||
|
This is the simplest implementation - just call flash attention directly.
|
||||||
|
For sparse policies, this method would implement block selection.
|
||||||
|
"""
|
||||||
|
from flash_attn import flash_attn_with_kvcache
|
||||||
|
|
||||||
|
# q is [batch, num_heads, head_dim], need to add seq dim
|
||||||
|
return flash_attn_with_kvcache(
|
||||||
|
q.unsqueeze(1), # [batch, 1, heads, dim]
|
||||||
|
k_cache,
|
||||||
|
v_cache,
|
||||||
|
cache_seqlens=cache_seqlens,
|
||||||
|
block_table=block_tables,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Chunked offload methods
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
def compute_chunked_prefill(
|
def compute_chunked_prefill(
|
||||||
self,
|
self,
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
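`compute_prefill` receives packed variable-length inputs described by `cu_seqlens_q`, `cu_seqlens_k` and the two `max_seqlen_*` values. The sketch below shows how those values relate to per-sequence lengths, using plain Python lists rather than tensors. `build_cu_seqlens` is a hypothetical helper written for this illustration; it is not part of nanovllm or flash-attn.

```python
from itertools import accumulate
from typing import List, Tuple


def build_cu_seqlens(seq_lens: List[int]) -> Tuple[List[int], int]:
    """Return (cumulative offsets of length batch+1, maximum sequence length)."""
    cu = [0] + list(accumulate(seq_lens))
    return cu, max(seq_lens)


seq_lens = [5, 3, 7]
cu_seqlens, max_seqlen = build_cu_seqlens(seq_lens)

# Sequence i occupies rows cu_seqlens[i]:cu_seqlens[i + 1] of the packed tensor.
spans = [(cu_seqlens[i], cu_seqlens[i + 1]) for i in range(len(seq_lens))]
print(cu_seqlens, max_seqlen, spans)
```

The batch size is `len(cu_seqlens) - 1`, which is the same quantity the XAttention policy later reads as `cu_seqlens_q.shape[0] - 1`.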
|
|||||||
@@ -108,6 +108,34 @@ class SparsePolicy(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def alloc_policy_metadata(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Pre-allocate GPU buffers for policy computation.
|
||||||
|
|
||||||
|
Called by the framework after KV cache allocation, but ONLY for GPU-only
|
||||||
|
mode (not CPU offload mode). Override this to pre-allocate buffers that
|
||||||
|
would otherwise be dynamically allocated during forward pass.
|
||||||
|
|
||||||
|
This is separate from initialize() which is used for CPU offload metadata.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_heads: Number of query heads
|
||||||
|
num_kv_heads: Number of KV heads (for GQA)
|
||||||
|
head_dim: Dimension per head
|
||||||
|
max_seq_len: Maximum sequence length (for buffer sizing)
|
||||||
|
dtype: Data type (typically float16/bfloat16)
|
||||||
|
device: Target device (cuda)
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
@@ -191,6 +219,87 @@ class SparsePolicy(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# GPU-only methods (non-chunked)
|
||||||
|
# These methods are used when all KV cache is on GPU, no CPU offload needed.
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
cu_seqlens_q: torch.Tensor,
|
||||||
|
cu_seqlens_k: torch.Tensor,
|
||||||
|
max_seqlen_q: int,
|
||||||
|
max_seqlen_k: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
layer_id: int,
|
||||||
|
block_tables: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute GPU-only prefill attention (non-chunked).
|
||||||
|
|
||||||
|
This method is used when all KV cache resides on GPU (no CPU offload).
|
||||||
|
Override this to implement sparse prefill attention for GPU-only mode.
|
||||||
|
Default implementation raises NotImplementedError.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: [total_q, num_heads, head_dim] query tensor (packed variable length)
|
||||||
|
k: [total_kv, num_kv_heads, head_dim] key tensor
|
||||||
|
v: [total_kv, num_kv_heads, head_dim] value tensor
|
||||||
|
cu_seqlens_q: [batch+1] cumulative sequence lengths for queries
|
||||||
|
cu_seqlens_k: [batch+1] cumulative sequence lengths for keys
|
||||||
|
max_seqlen_q: maximum query sequence length
|
||||||
|
max_seqlen_k: maximum key sequence length
|
||||||
|
softmax_scale: softmax scaling factor (1/sqrt(head_dim))
|
||||||
|
layer_id: transformer layer index
|
||||||
|
block_tables: [batch, max_blocks] paged attention block tables (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[total_q, num_heads, head_dim] attention output
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} does not implement compute_prefill for GPU-only mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
def compute_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
cache_seqlens: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
layer_id: int,
|
||||||
|
block_tables: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Compute GPU-only decode attention (non-chunked).
|
||||||
|
|
||||||
|
This method is used when all KV cache resides on GPU (no CPU offload).
|
||||||
|
Override this to implement sparse decode attention for GPU-only mode.
|
||||||
|
Default implementation raises NotImplementedError.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: [batch, num_heads, head_dim] query tensor (single token per sequence)
|
||||||
|
k_cache: [num_blocks, block_size, num_kv_heads, head_dim] paged key cache
|
||||||
|
v_cache: [num_blocks, block_size, num_kv_heads, head_dim] paged value cache
|
||||||
|
cache_seqlens: [batch] sequence lengths in cache
|
||||||
|
softmax_scale: softmax scaling factor (1/sqrt(head_dim))
|
||||||
|
layer_id: transformer layer index
|
||||||
|
block_tables: [batch, max_blocks] paged attention block tables (optional)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[batch, 1, num_heads, head_dim] attention output
|
||||||
|
"""
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"{self.__class__.__name__} does not implement compute_decode for GPU-only mode"
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Chunked offload methods (for CPU offload mode)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def compute_chunked_prefill(
|
def compute_chunked_prefill(
|
||||||
self,
|
self,
|
||||||
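The base-class change follows a common pattern: the chunked offload methods stay abstract, while the GPU-only methods are optional and raise `NotImplementedError` unless a subclass overrides them. A simplified sketch, assuming hypothetical class names (`PolicySketch`, `OffloadOnlyPolicy`, `GPUCapablePolicy`) and a reduced method signature that is not the real one:

```python
from abc import ABC, abstractmethod


class PolicySketch(ABC):
    """Hypothetical, simplified analogue of SparsePolicy."""

    @abstractmethod
    def select_blocks(self, available_blocks):
        ...

    def alloc_policy_metadata(self, **kwargs) -> None:
        # Optional hook: the default is a no-op.
        pass

    def compute_prefill(self, q, k, v):
        # Optional in GPU-only mode: subclasses opt in by overriding.
        raise NotImplementedError(
            f"{self.__class__.__name__} does not implement compute_prefill for GPU-only mode"
        )


class OffloadOnlyPolicy(PolicySketch):
    def select_blocks(self, available_blocks):
        return available_blocks


class GPUCapablePolicy(OffloadOnlyPolicy):
    def compute_prefill(self, q, k, v):
        return "attention-output"  # placeholder for a real kernel call


def supports_gpu_only(policy: PolicySketch) -> bool:
    # A policy supports GPU-only prefill if it overrides the base method.
    return type(policy).compute_prefill is not PolicySketch.compute_prefill


print(supports_gpu_only(OffloadOnlyPolicy()), supports_gpu_only(GPUCapablePolicy()))
```

With this design an existing offload-only policy keeps working unchanged, and the failure only appears if that policy is actually routed through the GPU-only path.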
|
|||||||
@@ -122,6 +122,271 @@ class XAttentionBSAPolicy(SparsePolicy):
|
|||||||
self._stats_total_selected_blocks = 0
|
self._stats_total_selected_blocks = 0
|
||||||
self._stats_num_chunks = 0
|
self._stats_num_chunks = 0
|
||||||
|
|
||||||
|
# Pre-allocated GQA expansion buffers (GPU-only mode)
|
||||||
|
# Set by alloc_policy_metadata(), None if not pre-allocated
|
||||||
|
self._k_expanded: torch.Tensor | None = None
|
||||||
|
self._v_expanded: torch.Tensor | None = None
|
||||||
|
self._max_seq_len: int = 0
|
||||||
|
|
||||||
|
def alloc_policy_metadata(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_dim: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Pre-allocate GQA expansion buffers for GPU-only mode.
|
||||||
|
|
||||||
|
These buffers are used by compute_prefill() to avoid dynamic allocation
|
||||||
|
during forward pass. The buffers are sized for max_seq_len and sliced
|
||||||
|
to actual seq_len during use.
|
||||||
|
|
||||||
|
Memory usage: 2 * num_heads * max_seq_len * head_dim * dtype_size
|
||||||
|
For 64K seq, 32 heads, 128 dim, fp16: 2 * 32 * 65536 * 128 * 2 = 1 GB
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_heads: Number of query heads
|
||||||
|
num_kv_heads: Number of KV heads (for GQA)
|
||||||
|
head_dim: Dimension per head
|
||||||
|
max_seq_len: Maximum sequence length
|
||||||
|
dtype: Data type
|
||||||
|
device: Target device
|
||||||
|
"""
|
||||||
|
# Only allocate if GQA (num_heads != num_kv_heads)
|
||||||
|
if num_heads == num_kv_heads:
|
||||||
|
logger.info(f"[XAttn] No GQA expansion needed (num_heads == num_kv_heads = {num_heads})")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Shape: [1, num_heads, max_seq_len, head_dim] for xattn_estimate format
|
||||||
|
# Also reused for BSA after squeeze/transpose to [seq_len, num_heads, head_dim]
|
||||||
|
shape = (1, num_heads, max_seq_len, head_dim)
|
||||||
|
self._k_expanded = torch.empty(shape, dtype=dtype, device=device)
|
||||||
|
self._v_expanded = torch.empty(shape, dtype=dtype, device=device)
|
||||||
|
self._max_seq_len = max_seq_len
|
||||||
|
|
||||||
|
memory_mb = 2 * num_heads * max_seq_len * head_dim * dtype.itemsize / (1024 * 1024)
|
||||||
|
logger.info(f"[XAttn] Pre-allocated GQA buffers: shape={shape}, memory={memory_mb:.1f} MB")
|
||||||
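The memory estimate in the `alloc_policy_metadata` docstring can be checked with a few lines of arithmetic. `gqa_buffer_bytes` is a hypothetical helper written for this check, not a function from the codebase; it simply evaluates the formula stated in the docstring.

```python
def gqa_buffer_bytes(num_heads: int, max_seq_len: int, head_dim: int, dtype_size: int) -> int:
    """Bytes for the K and V expansion buffers, each [1, num_heads, max_seq_len, head_dim]."""
    return 2 * num_heads * max_seq_len * head_dim * dtype_size


# Docstring example: 64K tokens, 32 query heads, head_dim 128, fp16 (2 bytes).
example = gqa_buffer_bytes(num_heads=32, max_seq_len=64 * 1024, head_dim=128, dtype_size=2)
print(example, example / 2**30)  # bytes and GiB
```

The cost scales linearly with `max_seq_len`, so doubling the maximum length to 128K under the same assumptions doubles the reservation to 2 GiB.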
|
|
||||||
|
# =========================================================================
|
||||||
|
# GPU-only methods (non-chunked)
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
|
def compute_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k: torch.Tensor,
|
||||||
|
v: torch.Tensor,
|
||||||
|
cu_seqlens_q: torch.Tensor,
|
||||||
|
cu_seqlens_k: torch.Tensor,
|
||||||
|
max_seqlen_q: int,
|
||||||
|
max_seqlen_k: int,
|
||||||
|
softmax_scale: float,
|
||||||
|
layer_id: int,
|
||||||
|
block_tables: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
GPU-only prefill attention using XAttention + BSA.
|
||||||
|
|
||||||
|
This method implements sparse attention for GPU-only mode:
|
||||||
|
1. Estimate block importance using xattn_estimate
|
||||||
|
2. Compute sparse attention using block_sparse_attn_func
|
||||||
|
|
||||||
|
Args:
|
||||||
|
q: Query tensor [total_q, num_heads, head_dim] (varlen packed)
|
||||||
|
k: Key tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
|
||||||
|
v: Value tensor [total_kv, num_kv_heads, head_dim] (varlen packed)
|
||||||
|
cu_seqlens_q: Cumulative sequence lengths for Q [batch+1]
|
||||||
|
cu_seqlens_k: Cumulative sequence lengths for K [batch+1]
|
||||||
|
max_seqlen_q: Maximum Q sequence length
|
||||||
|
max_seqlen_k: Maximum K sequence length
|
||||||
|
softmax_scale: Softmax scaling factor
|
||||||
|
layer_id: Transformer layer index
|
||||||
|
block_tables: Paged attention block tables; if provided, this method falls back to flash_attn_varlen_func
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Attention output [total_q, num_heads, head_dim]
|
||||||
|
"""
|
||||||
|
# When block_tables is provided (paged KV cache / prefix cache),
|
||||||
|
# fallback to flash_attn as XAttention expects contiguous K, V
|
||||||
|
if block_tables is not None:
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
block_table=block_tables,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not BSA_AVAILABLE:
|
||||||
|
# Fallback to flash attention if BSA not available
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not XATTN_AVAILABLE:
|
||||||
|
# Fallback to flash attention if xattn not available
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
from nanovllm.ops.xattn import xattn_estimate
|
||||||
|
|
||||||
|
# Get dimensions
|
||||||
|
total_q, num_heads, head_dim = q.shape
|
||||||
|
total_kv, num_kv_heads, _ = k.shape
|
||||||
|
|
||||||
|
# For now, assume batch_size = 1 (single sequence)
|
||||||
|
# TODO: Support batched varlen format
|
||||||
|
batch_size = cu_seqlens_q.shape[0] - 1
|
||||||
|
if batch_size != 1:
|
||||||
|
# Fallback to flash attention for batched input
|
||||||
|
from flash_attn import flash_attn_varlen_func
|
||||||
|
logger.warning(f"[XAttn] batch_size={batch_size} != 1, falling back to flash attention")
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
cu_seqlens_k=cu_seqlens_k,
|
||||||
|
max_seqlen_q=max_seqlen_q,
|
||||||
|
max_seqlen_k=max_seqlen_k,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
q_len = max_seqlen_q
|
||||||
|
k_len = max_seqlen_k
|
||||||
|
|
||||||
|
# Convert from varlen format [total, heads, dim] to [batch, heads, seq, dim]
|
||||||
|
# q: [q_len, num_heads, head_dim] -> [1, num_heads, q_len, head_dim]
|
||||||
|
Q = q.unsqueeze(0).transpose(1, 2) # [1, num_heads, q_len, head_dim]
|
||||||
|
K = k.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim]
|
||||||
|
V = v.unsqueeze(0).transpose(1, 2) # [1, num_kv_heads, k_len, head_dim]
|
||||||
|
|
||||||
|
# Expand KV for GQA - use pre-allocated buffers if available
|
||||||
|
if num_heads != num_kv_heads:
|
||||||
|
num_groups = num_heads // num_kv_heads
|
||||||
|
if self._k_expanded is not None and k_len <= self._max_seq_len:
|
||||||
|
# Use pre-allocated buffers with in-place expansion
|
||||||
|
K_exp = self._k_expanded[:, :, :k_len, :]
|
||||||
|
V_exp = self._v_expanded[:, :, :k_len, :]
|
||||||
|
# In-place GQA expansion: [1, num_kv_heads, k_len, head_dim] -> [1, num_heads, k_len, head_dim]
|
||||||
|
# Reshape K to [1, num_kv_heads, 1, k_len, head_dim] and broadcast to [1, num_kv_heads, num_groups, k_len, head_dim]
|
||||||
|
K_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
|
||||||
|
K.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
|
||||||
|
)
|
||||||
|
V_exp.view(1, num_kv_heads, num_groups, k_len, head_dim).copy_(
|
||||||
|
V.unsqueeze(2).expand(-1, -1, num_groups, -1, -1)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Fallback: dynamic allocation (when buffers not pre-allocated or seq too long)
|
||||||
|
K_exp, V_exp = expand_kv_for_gqa(K, V, num_heads)
|
||||||
|
else:
|
||||||
|
K_exp, V_exp = K, V
|
||||||
|
|
||||||
|
# Estimate block importance and get sparse mask
|
||||||
|
_, mask = xattn_estimate(
|
||||||
|
Q, K_exp,
|
||||||
|
chunk_size=self.chunk_size,
|
||||||
|
block_size=self.BSA_BLOCK_SIZE,
|
||||||
|
threshold=self.threshold,
|
||||||
|
use_triton=self.use_triton,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute block counts
|
||||||
|
q_block_num = (q_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||||
|
k_block_num = (k_len + self.BSA_BLOCK_SIZE - 1) // self.BSA_BLOCK_SIZE
|
||||||
|
|
||||||
|
# Prepare tensors for BSA
|
||||||
|
# q, k, v need to be [seq_len, num_heads, head_dim]
|
||||||
|
q_bsa = q # Already [q_len, num_heads, head_dim]
|
||||||
|
|
||||||
|
# For GQA with BSA, reuse the expanded K_exp, V_exp (convert to BSA format)
|
||||||
|
# K_exp: [1, num_heads, k_len, head_dim] -> [k_len, num_heads, head_dim]
|
||||||
|
if num_heads != num_kv_heads:
|
||||||
|
k_bsa = K_exp.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||||
|
v_bsa = V_exp.squeeze(0).transpose(0, 1) # [k_len, num_heads, head_dim]
|
||||||
|
else:
|
||||||
|
k_bsa = k
|
||||||
|
v_bsa = v
|
||||||
|
|
||||||
|
# Prepare BSA inputs
|
||||||
|
cu_seqlens_q_bsa = torch.tensor([0, q_len], dtype=torch.int32, device=q.device)
|
||||||
|
cu_seqlens_k_bsa = torch.tensor([0, k_len], dtype=torch.int32, device=k.device)
|
||||||
|
head_groups = torch.ones(num_heads, dtype=torch.int32, device=q.device)
|
||||||
|
|
||||||
|
# Trim mask to actual block counts
|
||||||
|
mask_trimmed = mask[:, :, :q_block_num, :k_block_num].contiguous()
|
||||||
|
|
||||||
|
# Compute sparse attention using BSA
|
||||||
|
output = block_sparse_attn_func(
|
||||||
|
q_bsa, k_bsa, v_bsa,
|
||||||
|
cu_seqlens_q_bsa,
|
||||||
|
cu_seqlens_k_bsa,
|
||||||
|
head_groups,
|
||||||
|
None, # key_padding_mask
|
||||||
|
mask_trimmed,
|
||||||
|
q_len, k_len,
|
||||||
|
p_dropout=0.0,
|
||||||
|
deterministic=True,
|
||||||
|
is_causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log block density (layer 0 only, to avoid repeating it for every layer)
|
||||||
|
if layer_id == 0:
|
||||||
|
selected_blocks = mask_trimmed.sum().item()
|
||||||
|
total_blocks = q_block_num * k_block_num * num_heads
|
||||||
|
density = selected_blocks / total_blocks if total_blocks > 0 else 1.0
|
||||||
|
logger.debug(f"[XAttn GPU-only] layer={layer_id}, q_blocks={q_block_num}, "
|
||||||
|
f"k_blocks={k_block_num}, density={density:.1%}")
|
||||||
|
|
||||||
|
return output
|
||||||
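The in-place GQA expansion in `compute_prefill` views the expanded buffer as `[1, num_kv_heads, num_groups, k_len, head_dim]` and broadcasts each KV head across its group. The resulting head ordering can be illustrated without tensors. `expand_kv_heads` below is a hypothetical list-based analogue written for this note; it is not the `expand_kv_for_gqa` helper from the codebase.

```python
from typing import List, Sequence


def expand_kv_heads(kv_heads: Sequence[str], num_heads: int) -> List[str]:
    """Map each of num_heads query heads to the KV head it shares."""
    num_kv_heads = len(kv_heads)
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be a multiple of num_kv_heads")
    num_groups = num_heads // num_kv_heads
    # Viewing the output as [num_kv_heads, num_groups] means expanded head
    # index h = kv * num_groups + g, so head h reads KV head h // num_groups.
    return [kv_heads[h // num_groups] for h in range(num_heads)]


expanded = expand_kv_heads(["kv0", "kv1"], num_heads=8)
print(expanded)
```

Each KV head is repeated consecutively rather than interleaved, which is the ordering a `[num_kv_heads, num_groups]` view produces.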
|
|
||||||
|
def compute_decode(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
cache_seqlens: torch.Tensor,
|
||||||
|
softmax_scale: float,
|
||||||
|
layer_id: int,
|
||||||
|
block_tables: torch.Tensor | None = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
GPU-only decode attention - delegates to FullAttentionPolicy.
|
||||||
|
|
||||||
|
XAttention is designed for long prefill sequences. For decode (single token),
|
||||||
|
we use FullAttentionPolicy which calls flash_attn_with_kvcache.
|
||||||
|
"""
|
||||||
|
from nanovllm.kvcache.sparse.full_policy import FullAttentionPolicy
|
||||||
|
return FullAttentionPolicy().compute_decode(
|
||||||
|
q, k_cache, v_cache, cache_seqlens, softmax_scale, layer_id, block_tables
|
||||||
|
)
|
||||||
|
|
||||||
|
# =========================================================================
|
||||||
|
# Chunked offload methods
|
||||||
|
# =========================================================================
|
||||||
|
|
||||||
def select_blocks(
|
def select_blocks(
|
||||||
self,
|
self,
|
||||||
available_blocks: List[int],
|
available_blocks: List[int],
|
||||||
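The block counts and the density figure logged by the XAttention `compute_prefill` are simple arithmetic over the block mask. A sketch with nested Python lists in place of the mask tensor; `block_density` is a hypothetical helper, and the block size of 128 is chosen only for illustration (the value of `BSA_BLOCK_SIZE` is not shown in this diff).

```python
from typing import List, Tuple


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def block_density(mask: List[List[List[bool]]], q_len: int, k_len: int,
                  block_size: int) -> Tuple[int, int, float]:
    """mask is [num_heads][q_blocks_padded][k_blocks_padded] of booleans."""
    q_blocks = ceil_div(q_len, block_size)
    k_blocks = ceil_div(k_len, block_size)
    # Trim padding blocks, as the diff does with mask[:, :, :q_block_num, :k_block_num].
    trimmed = [[row[:k_blocks] for row in head[:q_blocks]] for head in mask]
    selected = sum(bool(x) for head in trimmed for row in head for x in row)
    total = q_blocks * k_blocks * len(mask)
    density = selected / total if total > 0 else 1.0
    return q_blocks, k_blocks, density


# One head, a 4x4 padded mask that keeps every causal (lower-triangular) block.
causal_mask = [[[c <= r for c in range(4)] for r in range(4)]]
q_blocks, k_blocks, density = block_density(causal_mask, q_len=300, k_len=300, block_size=128)
print(q_blocks, k_blocks, density)
```

Because the denominator counts all `q_blocks * k_blocks` positions, a mask that keeps every causal block still reports a density below 100% (6 of 9 blocks here), which is worth remembering when reading the logged value.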
|
|||||||
@@ -124,24 +124,47 @@ class Attention(nn.Module):
|
|||||||
if k_cache.numel() and v_cache.numel():
|
if k_cache.numel() and v_cache.numel():
|
||||||
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
store_kvcache(k, v, k_cache, v_cache, context.slot_mapping)
|
||||||
|
|
||||||
|
# Get sparse_policy from kvcache_manager (required, never None after warmup)
|
||||||
|
# During warmup, kvcache_manager is not yet allocated
|
||||||
|
if context.kvcache_manager is None:
|
||||||
|
# Warmup phase: use flash_attn directly
|
||||||
|
if context.is_prefill:
|
||||||
|
return flash_attn_varlen_func(
|
||||||
|
q, k, v,
|
||||||
|
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
||||||
|
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
||||||
|
softmax_scale=self.scale, causal=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return flash_attn_with_kvcache(
|
||||||
|
q.unsqueeze(1), k_cache, v_cache,
|
||||||
|
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
||||||
|
softmax_scale=self.scale, causal=True,
|
||||||
|
)
|
||||||
|
sparse_policy = context.kvcache_manager.sparse_policy
|
||||||
|
assert sparse_policy is not None, "sparse_policy must not be None"
|
||||||
|
|
||||||
if context.is_prefill:
|
if context.is_prefill:
|
||||||
if context.is_chunked_prefill:
|
if context.is_chunked_prefill:
|
||||||
# Chunked prefill: merge attention from previous KV
|
# Chunked prefill: merge attention from previous KV (CPU offload mode)
|
||||||
o = self._chunked_prefill_attention(q, k, v, context)
|
o = self._chunked_prefill_attention(q, k, v, context)
|
||||||
elif context.block_tables is not None: # prefix cache
|
|
||||||
k, v = k_cache, v_cache
|
|
||||||
o = flash_attn_varlen_func(q, k, v,
|
|
||||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
|
||||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
|
||||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
|
||||||
else:
|
else:
|
||||||
o = flash_attn_varlen_func(q, k, v,
|
# GPU-only mode: use policy for attention
|
||||||
max_seqlen_q=context.max_seqlen_q, cu_seqlens_q=context.cu_seqlens_q,
|
# Use paged attention if block_tables provided, else use k, v directly
|
||||||
max_seqlen_k=context.max_seqlen_k, cu_seqlens_k=context.cu_seqlens_k,
|
if context.block_tables is not None:
|
||||||
softmax_scale=self.scale, causal=True, block_table=context.block_tables)
|
k_for_attn, v_for_attn = k_cache, v_cache
|
||||||
|
else:
|
||||||
|
k_for_attn, v_for_attn = k, v
|
||||||
|
o = sparse_policy.compute_prefill(
|
||||||
|
q, k_for_attn, v_for_attn,
|
||||||
|
context.cu_seqlens_q, context.cu_seqlens_k,
|
||||||
|
context.max_seqlen_q, context.max_seqlen_k,
|
||||||
|
self.scale, self.layer_id,
|
||||||
|
context.block_tables,
|
||||||
|
)
|
||||||
else: # decode
|
else: # decode
|
||||||
if context.is_chunked_prefill:
|
if context.is_chunked_prefill:
|
||||||
# Chunked decode: need to load all KV from CPU+GPU
|
# Chunked decode: need to load all KV from CPU+GPU (CPU offload mode)
|
||||||
# Store current decode token to per-layer decode buffer
|
# Store current decode token to per-layer decode buffer
|
||||||
# This is needed because GPU cache has no layer dimension,
|
# This is needed because GPU cache has no layer dimension,
|
||||||
# so all layers would overwrite each other in decode_slot.
|
# so all layers would overwrite each other in decode_slot.
|
||||||
@@ -152,9 +175,12 @@ class Attention(nn.Module):
|
|||||||
offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
|
offload_engine.write_to_decode_buffer(self.layer_id, pos_in_block, k.squeeze(0), v.squeeze(0))
|
||||||
o = self._chunked_decode_attention(q, k, v, context)
|
o = self._chunked_decode_attention(q, k, v, context)
|
||||||
else:
|
else:
|
||||||
o = flash_attn_with_kvcache(q.unsqueeze(1), k_cache, v_cache,
|
# GPU-only mode: use policy for attention
|
||||||
cache_seqlens=context.context_lens, block_table=context.block_tables,
|
o = sparse_policy.compute_decode(
|
||||||
softmax_scale=self.scale, causal=True)
|
q, k_cache, v_cache,
|
||||||
|
context.context_lens, self.scale, self.layer_id,
|
||||||
|
context.block_tables,
|
||||||
|
)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
def _chunked_prefill_attention(
|
def _chunked_prefill_attention(
|
||||||
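The routing that `Attention.forward` performs after this change can be summarised as a small decision function. This is a sketch of the control flow only: `route_attention` is a hypothetical helper that returns the name of the code path, and `is_chunked` stands for `context.is_chunked_prefill`, which the diff checks in both the prefill and decode branches.

```python
def route_attention(has_manager: bool, is_prefill: bool, is_chunked: bool) -> str:
    """Return the name of the attention path taken for a given context."""
    if not has_manager:
        # Warmup: kvcache_manager is not allocated yet, call flash-attn directly.
        return "flash_attn_varlen_func" if is_prefill else "flash_attn_with_kvcache"
    if is_prefill:
        # CPU offload uses the chunked path; GPU-only goes through the policy.
        return "_chunked_prefill_attention" if is_chunked else "sparse_policy.compute_prefill"
    return "_chunked_decode_attention" if is_chunked else "sparse_policy.compute_decode"


for args in [(False, True, False), (True, True, True), (True, True, False), (True, False, False)]:
    print(args, "->", route_attention(*args))
```

The former prefix-cache branch no longer appears as a separate case: the GPU-only prefill path passes `k_cache`/`v_cache` plus `block_tables` into the policy instead.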
|
|||||||
158
scripts/profile.sh
Executable file
@@ -0,0 +1,158 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# Profile bench.py using NVIDIA Nsight Systems (GPU-only mode)
|
||||||
|
#
|
||||||
|
# Usage:
|
||||||
|
# bash scripts/profile.sh [options]
|
||||||
|
#
|
||||||
|
# Options:
|
||||||
|
# --max-len LENGTH Max sequence length (default: 32768)
|
||||||
|
# --policy POLICY Sparse policy: full, xattn (default: xattn)
|
||||||
|
# --gpu GPU_ID GPU to use (default: 0)
|
||||||
|
# --gpu-util UTIL GPU memory utilization (default: 0.9)
|
||||||
|
# --input-len LENGTH Input length (default: max-len - 1)
|
||||||
|
# --bench-decode Run decode benchmark instead of prefill
|
||||||
|
#
|
||||||
|
# Output:
|
||||||
|
# results/nsys/bench_<policy>_<max_len>_<mode>_<timestamp>.nsys-rep
|
||||||
|
#
|
||||||
|
# Examples:
|
||||||
|
# bash scripts/profile.sh
|
||||||
|
# bash scripts/profile.sh --max-len 65536 --gpu-util 0.7
|
||||||
|
# bash scripts/profile.sh --policy full --max-len 32768
|
||||||
|
# bash scripts/profile.sh --bench-decode
|
||||||
|
|
||||||
|
set -e
|
||||||
|
|
||||||
|
# Default configuration
|
||||||
|
MAX_LEN="32768"
|
||||||
|
POLICY="xattn"
|
||||||
|
GPU_ID="0"
|
||||||
|
GPU_UTIL="0.9"
|
||||||
|
INPUT_LEN=""
|
||||||
|
BENCH_MODE="prefill"
|
||||||
|
|
||||||
|
# Parse arguments
|
||||||
|
while [[ $# -gt 0 ]]; do
|
||||||
|
case $1 in
|
||||||
|
--max-len)
|
||||||
|
MAX_LEN="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--policy)
|
||||||
|
POLICY="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--gpu)
|
||||||
|
GPU_ID="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--gpu-util)
|
||||||
|
GPU_UTIL="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--input-len)
|
||||||
|
INPUT_LEN="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--bench-decode)
|
||||||
|
BENCH_MODE="decode"
|
||||||
|
shift
|
||||||
|
;;
|
||||||
|
-h|--help)
|
||||||
|
echo "Usage: $0 [options]"
|
||||||
|
echo ""
|
||||||
|
echo "Options:"
|
||||||
|
echo " --max-len LENGTH Max sequence length (default: 32768)"
|
||||||
|
echo " --policy POLICY Sparse policy: full, xattn (default: xattn)"
|
||||||
|
echo " --gpu GPU_ID GPU to use (default: 0)"
|
||||||
|
echo " --gpu-util UTIL GPU memory utilization (default: 0.9)"
|
||||||
|
echo " --input-len LENGTH Input length (default: max-len - 1)"
|
||||||
|
echo " --bench-decode Run decode benchmark instead of prefill"
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "Unknown option: $1"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
# Path configuration
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
|
||||||
|
OUTPUT_DIR="$PROJECT_ROOT/results/nsys"
|
||||||
|
BENCH_SCRIPT="$PROJECT_ROOT/bench.py"
|
||||||
|
|
||||||
|
# Create output directory if needed
|
||||||
|
mkdir -p "$OUTPUT_DIR"
|
||||||
|
|
||||||
|
# Generate timestamp for unique filename
|
||||||
|
TIMESTAMP=$(date +%Y%m%d_%H%M%S)
|
||||||
|
|
||||||
|
# Convert max_len to human-readable format (e.g., 32768 -> 32k)
|
||||||
|
if [ "$MAX_LEN" -ge 1024 ]; then
|
||||||
|
MAX_LEN_SUFFIX="$((MAX_LEN / 1024))k"
|
||||||
|
else
|
||||||
|
MAX_LEN_SUFFIX="${MAX_LEN}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
OUTPUT_FILE="$OUTPUT_DIR/bench_${POLICY}_${MAX_LEN_SUFFIX}_${BENCH_MODE}_${TIMESTAMP}"
|
||||||
|
|
||||||
|
# Build bench.py arguments
|
||||||
|
BENCH_ARGS="--max-len $MAX_LEN --gpu-util $GPU_UTIL"
|
||||||
|
|
||||||
|
if [ -n "$POLICY" ]; then
|
||||||
|
BENCH_ARGS="$BENCH_ARGS --policy $POLICY"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -n "$INPUT_LEN" ]; then
|
||||||
|
BENCH_ARGS="$BENCH_ARGS --input-len $INPUT_LEN"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$BENCH_MODE" = "decode" ]; then
|
||||||
|
BENCH_ARGS="$BENCH_ARGS --bench-decode"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "============================================================"
|
||||||
|
echo "NVIDIA Nsight Systems Profiling (GPU-only)"
|
||||||
|
echo "============================================================"
|
||||||
|
echo "Bench script: $BENCH_SCRIPT"
|
||||||
|
echo "Policy: $POLICY"
|
||||||
|
echo "Max length: $MAX_LEN"
|
||||||
|
echo "GPU: $GPU_ID"
|
||||||
|
echo "GPU util: $GPU_UTIL"
|
||||||
|
echo "Bench mode: $BENCH_MODE"
|
||||||
|
echo "Output file: $OUTPUT_FILE.nsys-rep"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# nsys profile options:
|
||||||
|
# --trace=cuda,nvtx : Trace CUDA API and NVTX markers
|
||||||
|
# --force-overwrite=true : Overwrite existing output file
|
||||||
|
# --output=<path> : Output file path (without .nsys-rep extension)
|
||||||
|
|
||||||
|
echo "Running nsys profile..."
|
||||||
|
echo "Command: python bench.py $BENCH_ARGS"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
CUDA_VISIBLE_DEVICES=$GPU_ID PYTHONPATH="$PROJECT_ROOT:$PYTHONPATH" \
|
||||||
|
nsys profile \
|
||||||
|
--trace=cuda,nvtx \
|
||||||
|
--force-overwrite=true \
|
||||||
|
--output="$OUTPUT_FILE" \
|
||||||
|
python "$BENCH_SCRIPT" $BENCH_ARGS
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "============================================================"
|
||||||
|
echo "Profiling completed successfully!"
|
||||||
|
echo "============================================================"
|
||||||
|
echo "Output file: $OUTPUT_FILE.nsys-rep"
|
||||||
|
echo ""
|
||||||
|
echo "To view results in GUI:"
|
||||||
|
echo " nsight-sys $OUTPUT_FILE.nsys-rep"
|
||||||
|
echo ""
|
||||||
|
echo "To export statistics:"
|
||||||
|
echo " nsys stats --report cuda_api_sum $OUTPUT_FILE.nsys-rep"
|
||||||
|
echo " nsys stats --report cuda_gpu_kern_sum $OUTPUT_FILE.nsys-rep"
|
||||||
|
echo " nsys stats --report cuda_gpu_mem_size_sum $OUTPUT_FILE.nsys-rep"
|
||||||
|
echo "============================================================"
|
||||||
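The output file naming in `profile.sh` can be mirrored in Python to make the rule explicit. `nsys_output_stem` is a hypothetical helper that reproduces the script's `MAX_LEN_SUFFIX` logic and `OUTPUT_FILE` pattern; the timestamp in the example is an illustrative value, not real output.

```python
def nsys_output_stem(policy: str, max_len: int, mode: str, timestamp: str) -> str:
    """Build the report name (without directory or .nsys-rep extension)."""
    # Same rule as the script: integer division by 1024 for lengths >= 1024.
    suffix = f"{max_len // 1024}k" if max_len >= 1024 else str(max_len)
    return f"bench_{policy}_{suffix}_{mode}_{timestamp}"


stem = nsys_output_stem("xattn", 32768, "prefill", "20250101_120000")
print(stem + ".nsys-rep")
```

Since the division truncates, lengths that are not multiples of 1024 lose their remainder in the file name (for example 1500 becomes `1k`), so two different `--max-len` values can map to the same suffix.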